#!python
import argparse
import copy
import glob
import multiprocessing
import os
import re
import yaml
import json

# Colors
class colors:
    # ANSI SGR escape sequences used to decorate terminal output.
    # Concatenate one in front of the text and ENDC after it.

    # Attributes
    ENDC = '\033[0m'         # reset every attribute
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Bright foreground colors, in escape-code order
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'
    PURPLE = HEADER          # same escape code as HEADER
    OKCYAN = '\033[96m'

# Script class
class Script:
    """An ordered sequence of shell command templates.

    Each template is expanded with str.format() using the arguments given
    to run() and is then handed to the shell through os.system().
    """

    def __init__(self, *kargs):
        # One command template per positional argument, kept in order
        self.scripts = kargs

    def run(self, *args):
        """Run every command in order, stopping at the first failure.

        args are substituted into each template with str.format(). When a
        command returns a non-zero status the failing command is printed
        and SystemExit(-1) is raised, so later commands never run.
        """
        for script in self.scripts:
            command = script.format(*args)
            ret = os.system(command)

            if ret != 0:
                # Say which command failed instead of exiting silently.
                # SystemExit is raised directly because the exit() builtin
                # is installed by the site module and may be absent.
                print(f'Command failed (status {ret}): {command}')
                raise SystemExit(-1)

# Build class
class Build:
    cache = '.smake'
    bdir = '.smake/builds'
    tdir = '.smake/targets'

    # Default compiler and standard
    default_compiler = 'g++'
    default_standard = 'c++11'

    # Find header from include and source directories
    # TODO: needs optimization
    def get_header(self, source, header):
        # First check the header from source path
        sdir = os.path.dirname(source)
        hpath = sdir + '/' + header

        if os.path.exists(hpath):
            return os.path.normpath(hpath)

        # Check directories of the source first
        for sdir in self.sdirs:
            path = os.path.join(sdir, header)
            if os.path.exists(path):
                return path

        # Check include directories
        for idir in self.idirs:
            path = os.path.join(idir, header)
            if os.path.exists(path):
                return path
        
        # Assume its standard
        return None

    # Get header dependencies for a source
    #   which are not the standard library
    def get_deps(self, source):
        with open(source) as file:
            lines = file.readlines()

        for line in lines:
            if line.startswith('#include'):
                # TODO: join regexes?
                ht1 = re.search('#include <(.+)>', line)
                ht2 = re.search('#include \"(.+)\"', line)

                if ht1:
                    header = self.get_header(source, ht1.group(1))
                elif ht2:
                    header = self.get_header(source, ht2.group(1))
                
                if header and header not in self.all_deps:
                    self.all_deps.add(header)
                    self.get_deps(header)
                
        # Deep copy so that the dictionary does
        # not get modified from future dependency searches
        return copy.deepcopy(self.all_deps)

    # Build dependencies
    def build_deps(self, sources):
        # Dependency path
        dfile = self.bdir + '/' + self.build + '_dependencies.json'

        # Check if the file exists
        if os.path.exists(dfile):
            # Read the JSON
            with open(dfile) as file:
                deps = json.load(file)

            # Convert to set
            l = []
            for source in deps:
                l.append(source)

                for dep in deps[source]:
                    l.append(dep)

            # Check if the dependencies file needs to be updated
            mtime = os.path.getmtime(l[0])

            for source in l:
                mt = os.path.getmtime(source)
                if mt > mtime:
                    mtime = mt

            # If the dependencies file is up to date, return
            if mtime <= os.path.getmtime(dfile):
                self.deps = deps
                return

        # Create dependency map for each sort
        self.deps = dict()
        for source in sources:
            self.all_deps = set()   # Temporary variable for recursion

            # Get the dependencies and convert to a list
            sdeps = self.get_deps(source)
            sdeps = list(sdeps)

            self.deps.update({source: sdeps})

        # Store the dependencies to a file
        with open(dfile, 'w', encoding='utf-8') as file:
            json.dump(self.deps, file, ensure_ascii=False, indent=4)

    # Build is build name
    def __init__(self, target, build, compiler, sources,
        idirs = [], libs = [], ldirs = [], flags = []):
        # Set immediate properties
        self.target = target
        self.build = build
        self.sources = sources
        self.compiler = compiler

        # TODO: make function for directory creation
        # Create the cache directory
        if not os.path.exists(self.cache):
            os.mkdir(self.cache)

        # Create build dir
        if not os.path.exists(self.bdir):
            os.mkdir(self.bdir)
        
        # Create the output directory
        self.odir = self.bdir + '/' + self.build
        if not os.path.exists(self.odir):
            os.mkdir(self.odir)
        
        # Create target build directory
        self.tpath = self.tdir + '/' + self.target
        if not os.path.exists(self.tdir):
            os.mkdir(self.tdir)

        # Includes
        self.idirs = ''
        if len(idirs) > 0:
            self.idirs = ''.join([' -I ' + idir for idir in idirs])
        
        # Library dirs and rpath
        self.ldirs = ''
        if len(ldirs) > 0:
            self.ldirs = ''.join([' -L' + ldir for ldir in ldirs])

        # Runtime paths
        self.rpaths = ''
        if len(ldirs) > 0:
            self.rpaths = ' -Wl,'
            self.rpaths += ''.join(['-rpath ' + ldir for ldir in ldirs])

        # Libraries
        self.libs = ''
        if len(libs) > 0:
            self.libs = ''.join([lib if (lib[0] == '-') else (' -l' + lib) for lib in libs])

        # Flags
        self.flags = ''
        if len(flags) > 0:
            self.flags = ' '.join(flags)
        else:
            self.flags = '-std={}'.format(self.default_standard)

        # Populate the directory for each source
        self.sdirs = set()
        for source in sources:
            self.sdirs.add(os.path.dirname(source))

        # Build dependencies
        self.build_deps(sources)
        
    # Update target name
    def set_target(self, target):
        """Rename the output and recompute its path under self.tdir."""
        self.tpath = '/'.join((self.tdir, target))
        self.target = target

    # Get most recent time from source and its dependencies
    def getmtime(self, file):
        mtime = os.path.getmtime(file)
        for dep in self.deps[file]:
            mt = os.path.getmtime(dep)
            if mt > mtime:
                mtime = mt
        return mtime

    # Generate compile_commands.json
    def gen_ccmds(self):
        # Return string
        ret = ''

        # Iterate through all the sources
        for file in self.sources:
            # Get file name without directory
            fname = os.path.basename(file)

            # Create the output file
            ofile = os.path.join(self.odir, fname.replace('.cpp',
                '.o').replace('.c', '.o'))

            # Generate command
            cmd = '{} {} -c -o {} {} {}'.format(self.compiler, self.flags,
                    ofile, file, self.idirs)

            ret += '{{\n  "directory": "{}",\n  "command": "{}",\n  "file": "{}"\n}},\n'.format(
                    os.getcwd(), cmd, file)

        # Return the string
        return ret


    # Compile a single file
    def compile(self, lock_avaiable, file, msg, verbose = False):
        # Lock if available
        if lock_avaiable:
            lock.acquire()

        # Print the messahe
        print(msg, end='')

        # Check if file exists
        if not os.path.exists(file):
            print(colors.FAIL + f'\tFile {file} does not exist' + colors.ENDC)
            return ''

        # Get file name without directory
        fname = file.replace('/', '_')

        # Create the output file
        ofile = os.path.join(self.odir, fname.replace('.cpp',
            '.o').replace('.c', '.o'))

        # Check if compilation is necessary
        file_t = self.getmtime(file)

        if os.path.exists(ofile):
            ofile_t = os.path.getmtime(ofile)
            if ofile_t > file_t:
                print(colors.OKCYAN + 'Source already compiled' + colors.ENDC)
        
                # Unlock if available
                if lock_avaiable:
                    lock.release()
                
                return ofile

        # Compile the source
        cmd = '{} {} -c -o {} {} {}'.format(self.compiler, self.flags,
                ofile, file, self.idirs)

        # Print command if verbose
        if verbose:
            print(colors.PURPLE + cmd + colors.ENDC)
        else:
            print()
        
        # Unlock if available
        if lock_avaiable:
            lock.release()

        # Run command and check if it was successful
        ret = os.system(cmd)
        if ret != 0:
            return ''

        # Return object
        return ofile
    
    # Link source objects
    def link(self, compiled, fsources, verbose):
        """Link the object files into self.tpath.

        compiled -- object file paths to link
        fsources -- sources that failed to compile; when not empty they
                    are listed and linking is skipped
        verbose  -- print the link command

        Returns self.tpath on success and '' on any failure.
        """
        # TODO: avoid repeated linking

        # Refuse to link when any source failed to compile
        if len(fsources) > 0:
            headline = 'Failed to compile the following sources, skipping linking process'
            print('\n' + colors.FAIL + headline + colors.ENDC)

            for failed in fsources:
                print(colors.PURPLE + '\t' + failed + colors.ENDC)

            return ''

        # Assemble the linker invocation
        objects = ' '.join(compiled)
        cmd = '{} {} -o {} {} {} {}'.format(self.compiler, objects,
            self.tpath, self.ldirs, self.libs, self.rpaths)

        # Progress banner, padded so the optional command lines up
        banner = colors.OKBLUE + 'Linking executable {}... '.format(self.target) + colors.ENDC
        print(banner.ljust(75), end='')

        if verbose:
            print(colors.PURPLE + cmd + colors.ENDC)
        else:
            print()

        # Success is the short path; anything else is reported
        if os.system(cmd) == 0:
            return self.tpath

        print(colors.FAIL + 'Failed to link target {}'.format(self.target) + colors.ENDC)
        return ''

    # Compile all sources
    def run(self, verbose = False):
        """Compile every source one after another, then link.

        Returns whatever link() returns: the target path, or '' when a
        source failed to compile or linking failed.
        """
        # Progress prefix '[ i/N] Compiling <source>... ', with the
        # counter right-aligned to the width of N
        total = str(len(self.sources))
        template = (colors.OKCYAN + '[{:>' + str(len(total)) + '}/' + total
            + '] ' + colors.OKGREEN + 'Compiling {}... ' + colors.ENDC)

        objects = []    # result of every compile() call, in order
        failed = []     # sources whose compile() returned ''

        for position, source in enumerate(self.sources, start=1):
            result = self.compile(False, source,
                template.format(position, source), verbose)

            if result == '':
                failed.append(source)

            objects.append(result)

        # Link all the objects
        return self.link(objects, failed, verbose)
    
    # Make lock available to processes
    def init(self, l):
        """Store `l` as the module-level name `lock`.

        Used as the multiprocessing.Pool initializer so that compile()
        can reach the shared lock inside each worker process.
        """
        globals()['lock'] = l
    
    # Compile multithreaded
    def mt_run(self, threads, verbose = False):
        """Compile all sources in a pool of worker processes, then link.

        threads -- number of worker processes in the pool
        verbose -- print the compiler and linker commands

        Returns whatever link() returns: the target path, or '' when a
        source failed to compile or linking failed.
        """
        # Interprocess lock, handed to every worker through init()
        lock = multiprocessing.Lock()

        # Generate format string
        slen = str(len(self.sources))
        fstr = colors.OKCYAN + '[{:>' + str(len(slen)) + '}/' + slen + '] ' \
            + colors.OKGREEN + 'Compiling {}... ' + colors.ENDC

        # Create the arguments for each compile() call
        args = [
            (True, source, fstr.format(number, source), verbose)
            for number, source in enumerate(self.sources, start=1)
        ]

        with multiprocessing.Pool(threads, initializer = self.init, initargs=(lock,)) as pool:
            rets = pool.starmap(self.compile, args)

        # Compiled and failed
        compiled = []
        fsources = []

        # starmap keeps the input order, so results pair up with sources.
        # A failure is recorded by source name: the result itself is ''
        # and would show up as a blank line in the failure report.
        for source, obj in zip(self.sources, rets):
            if len(obj) == 0:
                fsources.append(source)
            else:
                compiled.append(obj)

        # Link the objects
        return self.link(compiled, fsources, verbose)

# Target class
class Target:
    """A named output that can be built in one of several modes.

    name       -- target name
    modes      -- list of mode names accepted by run()
    builds     -- dict mapping a mode to the build object used for it
    postbuilds -- dict mapping a mode to a script run after a successful
                  build of that mode
    """

    def __init__(self, name, modes, builds, postbuilds):
        self.name = name
        self.modes = modes
        self.builds = builds
        self.postbuilds = postbuilds

    # Generate compile_commands.json
    def gen_ccmds(self):
        """Concatenate the compile command entries of every build."""
        return ''.join(build.gen_ccmds() for build in self.builds.values())

    def run(self, mode = 'default', threads = 0, verbose = False):
        """Build the target in `mode` and run its postbuild script.

        mode    -- mode name; an empty value selects 'default'
        threads -- when greater than zero, the build's mt_run() is used
                   with that many workers, otherwise its run()
        verbose -- passed through to the build

        Prints an error and raises SystemExit(-1) when the mode is not
        listed in self.modes or has no build assigned.
        """
        # Empty is always a valid mode, but 'default' should be used
        if not mode:
            mode = 'default'

        # A mode can be listed without having a build ('default' is
        # always listed), so both collections are checked before the
        # lookup below
        if mode not in self.modes or mode not in self.builds:
            print(colors.FAIL + f'{mode} is not a valid build for target {self.name}' + colors.ENDC)
            raise SystemExit(-1)

        build = self.builds[mode]

        # Run the build
        if threads > 0:
            target = build.mt_run(threads, verbose)
        else:
            target = build.run(verbose)

        # Run postbuild with target argument, if present
        if mode in self.postbuilds:
            # The build reports failure with an empty result
            if len(target) == 0:
                print(colors.FAIL + '\nFailed to compile target, skipping postbuild script' + colors.ENDC)
                return

            postbuild = self.postbuilds[mode]
            print(colors.OKBLUE + '\nSucessfully compiled target, ' + \
                'running postbuild script\n' + colors.ENDC)
            postbuild.run(target)

# Global helper functions
def split_plain(elem):
    # Strings are split on ', ' into a list of items;
    # any other value is handed back untouched
    if not isinstance(elem, str):
        return elem
    return elem.split(', ')

def split(d, pr, defns):
    # Expand property `pr` of `d` into a flat list. Items that name a
    # definition are replaced by its value (a list value contributes all
    # of its elements); every other item is kept as written.
    expanded = []
    for item in split_plain(d[pr]):
        if item not in defns:
            expanded.append(item)
            continue

        value = defns[item]
        if isinstance(value, list):
            expanded.extend(value)
        else:
            expanded.append(value)

    return expanded

def concat(ldicts):
    # Merge a sequence of dictionaries into a single new one;
    # for a repeated key the value from the later dictionary wins
    return {
        key: value
        for mapping in ldicts
        for key, value in dict(mapping).items()
    }

# Config class
class Config:
    # Constructor takes no argument
    def __init__(self):
        """Load smake.yaml from the current directory.

        Prints an error and exits with -1 when the file does not exist.
        """
        # Start with no targets and no installations
        self.targets = {}
        self.installs = {}

        # The config file is always looked up in the current directory
        if not os.path.exists('smake.yaml'):
            print(colors.FAIL + 'No smake config file (smake.yaml) in current directory' + colors.ENDC)
            exit(-1)

        self.load_file('smake.yaml')

    # Reads definitions from variables like sources, includes, etc.
    def load_definitions(self, smake):
        """Read the 'definitions' section and the global defaults.

        Returns a dict mapping each definition name to its value after
        split_plain(). As a side effect, 'default_compiler' and
        'default_standard' from the config replace the class-level
        defaults of Build.
        """
        # TODO: error on duplicate definition

        # Output dictionary
        defns = {}

        # An empty 'definitions:' section is loaded as None, hence the
        # fallback to an empty list
        for dgroup in smake.get('definitions') or []:
            # Take every entry of the group, not only the first one
            for key, value in dgroup.items():
                defns[key] = split_plain(value)

        # Default compiler and standard
        if 'default_compiler' in smake:
            Build.default_compiler = smake['default_compiler']

        if 'default_standard' in smake:
            Build.default_standard = smake['default_standard']

        # Return the dictionary
        return defns
    
    # Create a build object
    def load_build(self, build, defns):
        """Create a Build from one entry of the 'builds' section.

        build -- single-key dict: build name -> list of property dicts
        defns -- definitions used to expand the property values

        Warns about unknown properties and exits with -1 when the build
        lists no sources.
        """
        name = list(build)[0]

        # Flatten the list of one-entry dictionaries into one mapping
        properties = {}
        for section in build[name]:
            properties.update(section)

        # Warn about sections smake does not understand
        known = ('sources', 'idirs', 'libraries', 'ldirs', 'flags', 'compiler')
        for key in properties:
            if key in known:
                continue
            print(colors.WARNING + f'Section \"{key}\" in builds is not used by smake' + colors.ENDC)

        # Sources are mandatory
        if 'sources' not in properties:
            print(colors.FAIL + f'Build \"{name}\" has no listed sources' + colors.ENDC)
            exit(-1)

        sources = split(properties, 'sources', defns)

        # List-valued optional properties default to an empty list
        optional = {
            key: split(properties, key, defns) if key in properties else []
            for key in ('idirs', 'libraries', 'ldirs', 'flags')
        }

        # The compiler is taken as written, without definition expansion
        compiler = properties.get('compiler', Build.default_compiler)

        # The target name is a placeholder here
        return Build('smake-build', name, compiler, sources,
            optional['idirs'], optional['libraries'], optional['ldirs'],
            optional['flags'])

    def load_all_builds(self, smake, defns):
        """Return {build name: Build} for the 'builds' section.

        Exits with -1 when the config has no 'builds' section.
        """
        if 'builds' not in smake:
            print(colors.FAIL + 'No builds defined in smake.yaml' + colors.ENDC)
            exit(-1)

        # Every entry is a single-key dict whose key is the build name
        return {
            list(entry)[0]: self.load_build(entry, defns)
            for entry in smake['builds']
        }

    def load_target(self, target, blist, defns):
        """Create a Target from one entry of the 'targets' section.

        target -- single-key dict: target name -> list of property dicts
        blist  -- {build name: Build} for all known builds
        defns  -- definitions used to expand modes and postbuild scripts

        Exits with -1 when the target lists no builds. Modes that refer
        to an undefined build are skipped with a warning.
        """
        name = list(target)[0]
        properties = {}
        for d in target[name]:
            properties.update(d)

        # Warn unused sections
        used = ['modes', 'builds', 'postbuilds']
        for key in properties:
            if key not in used:
                print(colors.WARNING + f'Section {key} in targets is not used by smake' + colors.ENDC)

        # Preprocess properties
        modes = ['default']
        if 'modes' in properties:
            modes.extend(split(properties, 'modes', defns))

        # Builds are mandatory
        if 'builds' not in properties:
            print(colors.FAIL + f'Target \"{name}\" has no listed builds' + colors.ENDC)
            exit(-1)

        # Resolve build names to Build objects
        builds = {}
        for mode, bname in concat(properties['builds']).items():
            if bname not in blist:
                print(colors.WARNING + f'Build \"{bname}\" used by target \"{name}\"' + \
                    f' is not defined, skipping mode \"{mode}\"' + colors.ENDC)
                continue

            # Each target works on its own copy: set_target() on a shared
            # Build would rename the output of every other target that
            # uses the same build
            build = copy.copy(blist[bname])
            build.set_target(name)
            builds[mode] = build

        # Turn every postbuild into a Script
        postbuilds = {}
        if 'postbuilds' in properties:
            for mode, pname in concat(properties['postbuilds']).items():
                # A postbuild either names a definition or is the
                # command itself
                if isinstance(pname, str) and pname in defns:
                    commands = defns[pname]
                else:
                    commands = pname

                # Definitions are stored as lists; wrap single commands
                if not isinstance(commands, (list, tuple)):
                    commands = [commands]

                postbuilds[mode] = Script(*commands)

        return Target(name, modes, builds, postbuilds)

    def load_all_targets(self, smake, builds, defns):
        """Return {target name: Target} for the 'targets' section.

        Exits with -1 when the config has no 'targets' section.
        """
        if 'targets' not in smake:
            print(colors.FAIL + 'No targets specified in smake.yaml' + colors.ENDC)
            exit(-1)

        # Every entry is a single-key dict whose key is the target name
        return {
            list(entry)[0]: self.load_target(entry, builds, defns)
            for entry in smake['targets']
        }
    
    # Read all installs
    def load_all_installs(self, smake, defns):
        """Return {install name: Script} for the 'installs' section.

        A config without that section yields an empty dict. Each value
        is split with split_plain() and its items become the commands of
        one Script.
        """
        if 'installs' not in smake:
            return {}

        return {
            script: Script(*split_plain(install[script]))
            for install in smake['installs']
            for script in install
        }

    # Read config file
    def load_file(self, file):
        """Parse the config at path `file` and populate this object.

        Fills self.targets and self.installs. Prints an error and exits
        with -1 when the document is not a mapping of sections.
        """
        # Open and read the config; the handle gets its own name so that
        # `file` keeps referring to the path
        with open(file, 'r', encoding='utf-8') as stream:
            smake = yaml.safe_load(stream)

        # An empty document is loaded as None; treat it as "no sections"
        # so the regular missing-section errors are reported
        if smake is None:
            smake = {}

        if not isinstance(smake, dict):
            print(colors.FAIL + f'{file} must contain a mapping of sections' + colors.ENDC)
            exit(-1)

        # Check unused sections
        # TODO: line numbers
        used = ['default_compiler', 'default_standard',
            'definitions', 'builds', 'targets', 'installs']

        for section in smake:
            if section not in used:
                print(colors.WARNING + f'Section {section} is not used by smake' + colors.ENDC)

        # Load the definitions
        defns = self.load_definitions(smake)

        # Load all builds
        builds = self.load_all_builds(smake, defns)

        # Load all targets
        self.targets = self.load_all_targets(smake, builds, defns)

        # Load all installs
        self.installs = self.load_all_installs(smake, defns)

    # List all targets
    def list_targets(self):
        """Print a two-column table of target names and their modes.

        Exits with -1 when no targets are loaded.
        """
        if len(self.targets) == 0:
            print(colors.FAIL + 'No targets found' + colors.ENDC)
            exit(-1)

        # Name column: longest target name plus padding
        width = max(len(name) for name in self.targets) + 5

        # Header message
        header = colors.OKCYAN + '{:<' + str(width) + '}{}' + colors.ENDC
        print(header.format('Target', 'Modes'))

        # One row per target, modes separated by ', '
        row = colors.OKBLUE + '{:<' + str(width) + '}' + \
            colors.PURPLE + '{}' + colors.ENDC
        for name, target in self.targets.items():
            print(row.format(name, ', '.join(target.modes)))

    # Generate compile_commands.json
    def gen_ccmds(self):
        # Final string
        ccmds = ''

        # Run the function on all targets
        for t in self.targets:
            ccmds += self.targets[t].gen_ccmds()

        # Last preprocessing
        ccmds = '[\n' + ccmds.rstrip(',\n') + '\n]'

        # Write to file
        with open('compile_commands.json', 'w') as file:
            file.write(ccmds)
    
    # Run the correct target
    def run(self, target, mode = 'default', threads = 0, verbose = False):
        """Run one target, or every target when `target` is 'all'.

        mode, threads and verbose are passed on to Target.run(). An
        unknown target name prints the list of valid names. Exits with
        -1 when no targets are loaded.
        """
        if len(self.targets) == 0:
            print(colors.FAIL + 'No targets in smake' + colors.ENDC)
            exit(-1)

        # 'all' runs every target in the order they were loaded
        if target == 'all':
            for entry in self.targets.values():
                entry.run(mode, threads, verbose)
            return

        # Unknown name: suggest the valid ones
        if target not in self.targets:
            print(colors.FAIL + f'No target {target} found.' + \
                ' Perhaps you meant one of the following:' + colors.ENDC)
            for valid in self.targets:
                print(colors.PURPLE + f'\t{valid}' + colors.ENDC)
            return

        self.targets[target].run(mode, threads, verbose)
    
    # Run installation for a target
    def install(self, target):
        """Run the install script registered under `target`.

        The script is run with an empty string as its only argument.
        Exits with -1 when `target` has no entry in the installs section.
        """
        if target in self.installs:
            self.installs[target].run('')
            return

        print(colors.FAIL + f'No target \"{target}\" in installs section' + colors.ENDC)
        exit(-1)

# Script as executable
if __name__ == '__main__':
    # Only the command line entry point needs these modules
    import shutil
    import sys

    # Build the parser
    parser = argparse.ArgumentParser()

    # {TARGET} -m {EXECUTOR} -j{THREADS}
    parser.add_argument('target', help="Target name", nargs='?', default='all')
    parser.add_argument('-m', "--mode", help="Execution mode", default='')
    parser.add_argument('-j', "--threads",
                        help = "Number of concurrent threads", type = int, default = 0)

    # Special args
    parser.add_argument('-l', '--list', help = 'List all targets', action = 'store_true')
    parser.add_argument('-C', '--clear-cache', help = 'Clear smake cache', action = 'store_true')
    parser.add_argument('-G', '--gen-ccmds', help = 'Generate compile_commands.json', action = 'store_true')
    parser.add_argument('-v', '--verbose', help = 'Verbose mode', action = 'store_true')
    parser.add_argument('-I', '--install', help = 'Install target', action = 'store_true')

    # Read the arguments
    args = parser.parse_args()

    # No need to read config for this. The cache directory is removed
    # with shutil so that no external 'rm' command is required; a cache
    # that does not exist is not an error.
    if args.clear_cache:
        shutil.rmtree('.smake', ignore_errors=True)
        sys.exit(0)

    # Create the local config (reads smake.yaml from the current directory)
    config = Config()

    # Create compile_commands.json
    if args.gen_ccmds:
        config.gen_ccmds()
        sys.exit(0)

    # Run the target
    if args.list:
        config.list_targets()
    elif args.install:
        config.install(args.target)
    else:
        config.run(args.target, args.mode, args.threads, args.verbose)
