#!/usr/bin/env python3
"""SameStr: Shared Strains Identification in Metagenomic Samples."""


import glob
import logging
import os
import pathlib

from datetime import datetime
from os.path import basename, dirname, realpath, isdir, isfile
from os import environ, makedirs
from sys import stderr, argv

import numpy as np

from samestr import __version__

from samestr.convert import sam2bam, concatenate_gene_files, bam2freq
from samestr.compare import compare
from samestr.db import generate_db
from samestr.extract import ref2freq
from samestr.filter import filter_freqs
from samestr.merge import freq2freqs
from samestr.stats import aln2stats
from samestr.summarize import summarize
from samestr.utils.citations import citations
from samestr.utils.file_mapping import (
	clade_path,
    get_uniform_extension, 
    set_output_structure,
    spread_args_by_input_files,
)
from samestr.utils.ooSubprocess import parallelize_async, serialize
from samestr.utils.parse_args import read_params
from samestr.utils.utilities import (
	list_group_by_basename,
	load_samestr_db_manifest,
	read_json,
	write_json,
)


# Absolute directory of this module and its ``convert`` / ``utils``
# subdirectories (each with a trailing path separator).
SAMESTR_DIR = dirname(realpath(__file__))
CONVERT_DIR = os.path.join(SAMESTR_DIR, 'convert', '')
UTILS_DIR = os.path.join(SAMESTR_DIR, 'utils', '')

# Append both subdirectories to PATH -- presumably so helper executables kept
# there can be invoked by name (TODO confirm against the convert/utils code).
# os.pathsep keeps this correct on platforms whose separator is not ':', and
# skipping an empty/unset PATH avoids both a KeyError and a leading empty
# entry (which POSIX shells interpret as the current directory).
environ['PATH'] = os.pathsep.join(
    entry
    for entry in (environ.get('PATH', ''), UTILS_DIR, CONVERT_DIR)
    if entry
)


def samestr(input_args):
    """Run the SameStr sub-command selected by ``input_args['command']``.

    The function loads the SameStr database manifest, optionally checks the
    database files on disk, validates the input file extensions, dispatches
    to the worker of the requested command and finally appends a record of
    the invocation to ``<output_dir>/command_log.json``.

    Args:
        input_args: dict of parsed command-line options. Keys read here
            include ``command``, ``output_dir``, ``marker_dir``, ``nprocs``,
            ``input_files`` and, depending on the command, ``input_dir``,
            ``input_names``, ``samples_select``, ``clade``, ``db_check``,
            ``db_version``, ``markers_fasta`` and ``markers_info``. The dict
            is modified in place (``samestr_db``, ``input_extension``,
            ``input_files``, ... are added or replaced).

    Raises:
        SystemExit: with status 1 on invalid arguments, an unreadable
            database or inconsistent inputs; with status 0 once a requested
            database check has finished.
        ValueError: for ``merge`` when neither or both of ``input_files``
            and ``input_dir`` are given, or when ``input_dir`` contains no
            ``.npz`` file.
    """
    # Look the logger up here so the function does not depend on a
    # module-level LOG having been bound by the script entry point.
    log = logging.getLogger(__name__)

    command = input_args['command']

    ## Log command
    samestr_cmd = {
        'version': f'{__version__}',
        'command': ' '.join(argv),
        'date-start': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    ## Get database manifest files
    if command == 'db' and not input_args.get('db_check'):
        for required in ('db_version', 'markers_fasta', 'markers_info'):
            # A parser may store None for options that were not supplied, so
            # test the value rather than only the presence of the key.
            if input_args.get(required) is None:
                log.error('The following argument is required: %s', required)
                raise SystemExit(1)
        # When building a database, it is read from / written to output_dir.
        input_args['marker_dir'] = input_args['output_dir']
    db_existed, db_manifest, db_clades, db_taxonomy = load_samestr_db_manifest(
        input_args['marker_dir'])

    #### Print Info
    if db_existed:
        database = db_manifest['database']
        records = db_manifest['records']
        log.info(
            'Processing data with existing SameStr database based on %s, %s',
            database['name'], database['version'])
        log.info(
            'Database contains %s files, %s clades, %s markers, '
            'totalling %s positions.',
            records['total_n_files'],
            records['total_n_clades'],
            records['total_n_markers'],
            records['total_n_positions'])
    elif command != 'db':
        # Report the directory the manifest was actually looked up in.
        log.error('Aborting. Could not parse database at %s',
                  input_args['marker_dir'])
        raise SystemExit(1)

    #### Attach to input var, don't update manifest yet.
    input_args['samestr_db'] = {
        'db_existed': db_existed,
        'db_manifest': db_manifest,
        'db_clades': db_clades,
        'db_taxonomy': db_taxonomy,
        'samestr_cmd': samestr_cmd,
    }

    #### check db integrity if requested and db existed
    if input_args.get("db_check", False):
        if not db_existed:
            log.error('Cannot conduct a SameStr database check since a '
                      'database was not provided.')
            raise SystemExit(1)

        log.info('Checking integrity of the SameStr database.')
        suffixes = ('.contig_map.txt.gz', '.gene_file.txt.gz',
                    '.markers.fa.gz', '.positions.txt.gz')
        missing = []
        n_found = 0
        # Every clade listed in the manifest is expected to have one file per
        # suffix below its clade path inside marker_dir.
        for clade_name in db_clades['records']:
            file_base = '%s/%s' % (input_args['marker_dir'],
                                   clade_path(clade_name, filebase=True))
            for suffix in suffixes:
                expected = file_base + suffix
                if isfile(expected):
                    n_found += 1
                else:
                    missing.append(expected)

        if n_found == db_manifest['records']['total_n_files']:
            log.info('Number of files in the SameStr database matches the manifest.')
        if missing:
            outfn = os.path.join(input_args['output_dir'], 'db_check.tsv')
            log.warning('Files missing from the SameStr database: %s. Saving to %s',
                        len(missing), outfn)
            makedirs(input_args['output_dir'], exist_ok=True)
            with open(outfn, 'w', encoding="utf8") as f:
                f.write('\n'.join(missing))
        # A database check never continues into a regular command.
        raise SystemExit(0)

    ## Check input file extensions
    accepted_extensions_dict = {
        'db': ['.fasta', '.fa', '.fna', '.fasta.bz2', '.fa.bz2', '.fna.bz2'],
        'convert': ['.sam', '.sam.bz2', '.bam'],
        'extract': ['.fasta', '.fa', '.fna', '.fasta.gz', '.fa.gz', '.fna.gz'],
        'merge': ['.npy', '.npy.gz', '.npz'],
        'filter': ['.npy', '.npy.gz', '.npz'],
        'stats': ['.npy', '.npy.gz', '.npz'],
        'compare': ['.npy', '.npy.gz', '.npz'],
        'summarize': [''],
    }

    #### preprocess input/output files & extensions and map to commands
    accepted_extensions = accepted_extensions_dict[command]
    if command == 'db':
        input_args['input_extension'] = get_uniform_extension(
            [input_args['markers_fasta']], accepted_extensions)
    elif command != 'summarize':

        if command == 'merge':
            input_dir = input_args.get('input_dir')
            input_files = input_args.get('input_files')
            if not input_dir and not input_files:
                raise ValueError("Neither --input-files nor --input-dir are set.")
            if input_dir and input_files:
                raise ValueError("--input-files and --input-dir cannot be used simultaneously.")
            if input_dir:
                # Directory mode: collect every .npz file below input_dir.
                input_args['input_files'] = glob.glob(
                    os.path.join(input_dir, "**", "*.npz"), recursive=True)
                if not input_args['input_files']:
                    raise ValueError(f"Cannot find any .npz files in {input_dir}.")

        input_args['input_extension'] = get_uniform_extension(
            input_args['input_files'], accepted_extensions)

        # Count input files
        log.debug('Number of input files: %s', len(input_args['input_files']))

    ## check make output dir
    pathlib.Path(input_args['output_dir']).mkdir(exist_ok=True, parents=True)

    input_clade = input_args.get("clade")

    ################################
    ## process individual commands
    if command == 'db':

        makedirs(os.path.join(input_args['output_dir'], 'db_markers', ''),
                 exist_ok=True)

        # expand and generate db from markers/info files
        generate_db(input_args)

    elif command == 'extract':
        serialize(ref2freq, [input_args])

    elif command == 'convert':
        # TODO: offer/fix dir-based input
        # input_args['input_files'], input_args['input_extension'] = collect_input_files(
        #     input_args['input_dir'],
        #     accepted_extensions,
        #     recursive=input_args["recursive_input"]
        # )

        input_args['input_files'] = list_group_by_basename(
            input_args['input_files'],
            cut_name_endings=[input_args['input_extension']])

        # Spread args over files/file-pairs, Set output names and check/make their dir
        input_args['input_sequence_type'] = 'single'
        cmd_args = spread_args_by_input_files(input_args)
        cmd_args = set_output_structure(cmd_args)

        # Run: convert. Each stage consumes the result of the previous one.
        for stage in (concatenate_gene_files, sam2bam, bam2freq):
            cmd_args = parallelize_async(stage, cmd_args, input_args['nprocs'])

    elif command == 'merge':

        clade_file_dict = {}

        # group input files by clade (the file name up to its first '.')
        for file_path in sorted(input_args['input_files']):
            fn = basename(file_path)
            clade = fn.split('.')[0]

            # only process selected clades (or all clades if no selected clade)
            if input_clade is not None and clade not in input_clade:
                continue

            # a names file next to the input means the file has been merged
            # before: it then stands for a list of samples, not a single one
            sample_names = file_path.replace(
                input_args['input_extension'], '.names.txt'
            )
            if isfile(sample_names):
                with open(sample_names, 'r', encoding="utf8") as s_n:
                    sample = s_n.read().strip().split('\n')
            else:
                sample = fn.split(input_args['input_extension'])[0].replace(f'{clade}.', '')

            clade_file_dict.setdefault(clade, []).append([sample, file_path])

        cmd_args = [
            {
                "clade": clade,
                "input_files": files,
                "output_dir": input_args["output_dir"],
            }
            for clade, files in clade_file_dict.items()
        ]

        parallelize_async(freq2freqs, cmd_args, input_args['nprocs'])

    elif command in ('filter', 'stats', 'compare'):

        # For each species: group freqs with resp names.
        # Each value grows into [profile, names file, selection file or None].
        freqs = {
            basename(fn).split(input_args['input_extension'])[0]: [fn]
            for fn in input_args['input_files']
        }

        for name in input_args['input_names']:
            clade = basename(name).split('.names.txt')[0]
            if clade not in freqs:
                log.warning('Skipping %s. Found name file '
                            'but no SNV profile.', clade)
            else:
                freqs[clade].append(name)

        # iterate over a snapshot because entries are removed while looping
        for clade, freq in list(freqs.items()):
            if len(freq) != 2:
                log.warning('Skipping [%s]. Found SNV profile '
                            'but no name file: %s', clade, ', '.join(freq))
                freqs.pop(clade)
            elif input_clade is not None and clade not in input_clade:
                freqs.pop(clade)

        # attach sample selection file (the option may be absent or None)
        for i_s in input_args.get('samples_select') or []:
            clade = basename(i_s).split('.select.txt')[0]
            if clade not in freqs:
                log.warning('Skipping %s. Found selection file '
                            'but no SNV profile.', clade)
            else:
                freqs[clade].append(i_s)

        for clade, freq in list(freqs.items()):
            if len(freq) > 3:
                log.warning('Skipping %s. Found more than one '
                            'sample selection file: %s', clade, ', '.join(freq))
                freqs.pop(clade)
            elif len(freq) == 2:
                # no selection file for this clade
                freq.append(None)
            elif len(freq) != 3:
                log.error('Unexpected number of files (%s): %s.', len(freq), clade)
                raise SystemExit(1)

        # Spread args over clades
        ignore_args = {'input_files', 'input_names', 'input_select', 'clade'}
        arg_template_d = {
            arg: value
            for arg, value in input_args.items()
            if arg not in ignore_args
        }

        cmd_args = []
        for clade, (input_file, input_name, input_select) in freqs.items():
            args_d = {
                "input_file": input_file,
                "input_name": input_name,
                "input_select": input_select,
                "clade": clade,
            }
            # shared options are applied last, as before
            args_d.update(arg_template_d)
            cmd_args.append(args_d)

        # log message and worker function for each of the three commands
        runners = {
            'filter': ('Filtering files: %s', filter_freqs),
            'stats': ('Gathering statistics: %s', aln2stats),
            'compare': ('Comparing alignments: %s', compare),
        }
        message, worker = runners[command]
        log.info(message, len(freqs))
        parallelize_async(worker, cmd_args, input_args['nprocs'])

    elif command == 'summarize':
        serialize(summarize, [input_args])

    # write current command to logfile
    samestr_cmd['date-end'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logfile_fn = os.path.join(input_args['output_dir'], 'command_log.json')
    logfile = read_json(logfile_fn) if isfile(logfile_fn) else {}
    # Entries are numbered consecutively; an empty log starts at 1. Keys read
    # back from JSON are strings, so the new key is stored as a string too.
    run_index = max((int(key) for key in logfile), default=0) + 1
    logfile[str(run_index)] = samestr_cmd
    write_json(logfile_fn, logfile)
    log.info('Done.')


if __name__ == "__main__":
    # Parsed command-line options, kept in the module-level name ``args``.
    args = read_params()

    # A citation request is answered right away; nothing else is run.
    if args['citation'] is not None:
        print(citations[args['citation']])
        exit(0)

    # Send all log records to stderr at the level named by ``verbosity``.
    logging.basicConfig(
        stream=stderr,
        format='%(asctime)s | %(levelname)s | %(name)s | %(funcName)s | %(lineno)d | %(message)s',
        level=getattr(logging, args['verbosity']),
    )
    LOG = logging.getLogger(__name__)

    # Seed NumPy's global random generator only when a seed was supplied.
    if args.get('random_seed') is not None:
        np.random.seed(args['random_seed'])

    samestr(args)
