#! /usr/bin/env python

"""
check-includes.py <file...>

Checks that the includes are sorted properly and follow the "system headers
before local headers" rule.

Ignores what is in #if blocks to avoid false positives.
"""

import re
import sys

def exclude_if_blocks(lines):
    '''Yields only the lines that sit outside any #if ... #endif block.

    Lines starting with '#if' (which also covers #ifdef / #ifndef) open a
    block and lines starting with '#endif' close one. The opening and
    closing lines themselves are never yielded.
    '''
    depth = 0
    for line in lines:
        if line.startswith('#if'):
            depth += 1
            continue
        if line.startswith('#endif'):
            depth -= 1
            continue
        # Only lines at nesting depth zero are outside every block.
        if depth == 0:
            yield line

def filter_includes(lines):
    '''Removes lines that are not #include and keeps only the file part.

    For every line that begins with an #include directive and does not
    contain 'NOLINT', yields the first whitespace-delimited token that
    follows the directive (for instance '<stdio.h>' or '"foo/bar.h"').

    Any amount of whitespace between the directive and the file is accepted:
    several spaces, tabs, or none at all ('#include<foo.h>'). A directive
    with nothing after it is skipped instead of raising IndexError.
    '''
    # \b stops '#include' from matching the start of a longer word while
    # still letting '#include_next' through; \s* tolerates any separator,
    # including a missing one.
    directive = re.compile(r'#include(?:_next)?\b\s*(\S+)')
    for line in lines:
        if 'NOLINT' in line:
            continue
        found = directive.match(line)
        if found:
            yield found.group(1)

class IncludeFileSorter(object):
    '''Sort key wrapping one include token, delimiters included.

    The wrapped path is expected to look like '<foo/bar.h>' or '"foo/bar.h"':
    the first character tells system and local headers apart, and the first
    and last characters are dropped before the path itself is compared.
    '''

    def __init__(self, path):
        self.path = path

    def __lt__(self, other):
        '''Tells whether this include must come before the other one.

        * A different opening delimiter decides immediately: the include
          that does not start with " goes first.
        * Otherwise the paths are compared case-insensitively, directory by
          directory; the first differing directory decides.
        * When all shared directories are equal, the path with fewer
          components goes first.
        * With the same number of components, the last components are
          compared.
        '''
        mine, theirs = self.path, other.path
        if mine[0] != theirs[0]:
            return mine[0] != '"'

        my_parts = mine[1:-1].lower().split('/')
        their_parts = theirs[1:-1].lower().split('/')

        # Walk the leading directories both paths have; zip() stops at the
        # shorter one, and the final component is deliberately left out.
        for my_dir, their_dir in zip(my_parts[:-1], their_parts[:-1]):
            if my_dir != their_dir:
                return my_dir < their_dir

        if len(my_parts) > len(their_parts):
            return False
        if len(my_parts) < len(their_parts):
            return True
        return my_parts[-1] < their_parts[-1]

    def __eq__(self, other):
        '''Case-insensitive equality on the full token.'''
        return self.path.lower() == other.path.lower()

def sort_includes(includes):
    '''Returns a new list with the includes in their expected order.

    The ordering is the one defined by IncludeFileSorter; the input is left
    untouched.
    '''
    ordered = list(includes)
    ordered.sort(key=IncludeFileSorter)
    return ordered

def show_differences(bad, good):
    '''Formats the current and expected include lists as two columns.

    Both arguments are lists of strings. The left column holds the current
    order under a 'Current' heading and is padded to four characters past
    its longest entry; the right column holds the expected order under a
    'Should be' heading. Returns the rows joined with newlines.
    '''
    current = ['    Current'] + bad
    expected = ['    Should be'] + good
    width = max(map(len, current)) + 4
    rows = []
    for left, right in zip(current, expected):
        rows.append(left.ljust(width) + right)
    return '\n'.join(rows)

def check_file(path):
    '''Checks the include order of one file and reports problems on stderr.

    Prints a progress line on stdout, reads the file as UTF-8, drops what is
    inside #if blocks, extracts the #include lines and compares their order
    with the sorted order. When they differ, a message and a two-column
    "current / should be" listing are written to stderr.

    A file that is not valid UTF-8 is reported on stderr and skipped.
    Returns None in every case.
    '''
    print('Checking %s' % path)
    try:
        try:
            # 'with' makes sure the handle is closed even if decoding fails.
            with open(path, encoding='utf-8') as source:
                data = source.read()
        except TypeError: # py2: open() has no 'encoding' keyword there
            with open(path) as source:
                data = source.read().decode('utf-8')
    except UnicodeDecodeError:
        sys.stderr.write('%s: bad UTF-8 data\n' % path)
        return

    # Strip each line so indentation and stray '\r' do not hide directives.
    lines = (l.strip() for l in data.split('\n'))
    lines = exclude_if_blocks(lines)
    includes = list(filter_includes(lines))
    sorted_includes = sort_includes(includes)
    if includes != sorted_includes:
        sys.stderr.write('%s: includes are incorrect\n' % path)
        sys.stderr.write(show_differences(includes, sorted_includes) + '\n')

if __name__ == '__main__':
    # Every command-line argument is taken as the path of a file to check.
    for source_path in sys.argv[1:]:
        check_file(source_path)