#!/usr/bin/env python3.6

import argparse
import subprocess
import sys
import zipfile
from pathlib import Path

import yaml
import tqdm

import knock_knock

def check_blastn():
    """Exit with an error unless blastn version 2.7.1 can be run.

    Runs ``blastn -version`` and looks for ``2.7.1`` in its output. If the
    executable cannot be launched, exits unsuccessfully, or reports a
    different version, an explanatory message is printed and the program
    terminates with exit status 1.

    Raises:
        SystemExit: if a suitable blastn is not available.
    """
    try:
        output = subprocess.check_output(['blastn', '-version'])
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing or non-executable binary;
        # CalledProcessError covers a non-zero exit status. Anything else
        # (e.g. KeyboardInterrupt) is deliberately allowed to propagate.
        found = False
    else:
        found = b'2.7.1' in output

    if not found:
        print('blastn 2.7.1 is required and couldn\'t be found')
        # A missing dependency is a failure, so report a non-zero status.
        sys.exit(1)

def check_parallel():
    """Exit with an error unless GNU parallel can be run.

    Runs ``parallel --version`` and checks that the output starts with
    ``GNU parallel``. If the executable cannot be launched, exits
    unsuccessfully, or identifies itself as something else, an explanatory
    message is printed and the program terminates with exit status 1.

    Raises:
        SystemExit: if GNU parallel is not available.
    """
    try:
        output = subprocess.check_output(['parallel', '--version'])
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing or non-executable binary;
        # CalledProcessError covers a non-zero exit status. Anything else
        # (e.g. KeyboardInterrupt) is deliberately allowed to propagate.
        found = False
    else:
        found = output.startswith(b'GNU parallel')

    if not found:
        print('GNU parallel is required and couldn\'t be found')
        # A missing dependency is a failure, so report a non-zero status.
        sys.exit(1)

def parallel(args):
    """Process every matching experiment by fanning out with GNU parallel.

    Looks up the experiments in ``args.project_directory`` that satisfy
    ``args.conditions`` (restricted to ``args.group`` when one is given) and
    runs ``knock-knock process`` once per experiment, with at most
    ``args.max_procs`` jobs running at a time.

    Args:
        args: parsed command-line arguments; reads ``project_directory``,
            ``conditions``, ``group``, ``max_procs`` and ``stages``.

    Raises:
        SystemExit: if GNU parallel is unavailable, or (with status 0) if no
            experiments satisfy the conditions.
        subprocess.CalledProcessError: if the parallel invocation exits with
            a non-zero status.
    """
    from knock_knock import experiment

    check_parallel()

    conditions = args.conditions
    if args.group:
        # Work on a copy so neither the parsed arguments nor argparse's
        # shared default dict is mutated. `or {}` tolerates a --conditions
        # value that parsed to None (e.g. an empty yaml string).
        conditions = dict(conditions or {})
        conditions['group'] = args.group

    exps = experiment.get_all_experiments(args.project_directory, conditions)

    if len(exps) == 0:
        print('No experiments satisfy conditions:')
        print(conditions)
        sys.exit(0)

    parallel_command = [
        'parallel',
        '--will-cite',
        # Each job consumes 4 arguments: group, name, '--stages', stages.
        '-n', '4',
        '--bar',
        '--max-procs', str(args.max_procs),
        'knock-knock',
        'process',
        str(args.project_directory),
        ':::',
    ]

    # Sorted so that jobs are submitted in a deterministic order.
    arg_tuples = sorted((exp.group, exp.name, '--stages', args.stages) for exp in exps)
    for arg_tuple in arg_tuples:
        parallel_command.extend(arg_tuple)

    subprocess.check_call(parallel_command)

def process(args):
    """Run the requested processing stages on a single sample.

    Loads the sample sheet for ``args.group``, picks the experiment class
    matching the sample's ``platform`` entry, and calls ``process`` on the
    resulting experiment once for each comma-separated stage in
    ``args.stages``.

    Args:
        args: parsed command-line arguments; reads ``project_directory``,
            ``group``, ``sample``, ``progress`` and ``stages``.

    Raises:
        SystemExit: if blastn is unavailable, or (with status 1) if the group
            or the sample cannot be found.
        ValueError: if the sample's platform is neither ``pacbio`` nor
            ``illumina``.
    """
    from knock_knock import experiment

    check_blastn()

    sheet = experiment.load_sample_sheet(args.project_directory, args.group)

    # Guard clauses: bail out as soon as the group or the sample is unknown.
    if sheet is None:
        print(f'Error: {args.group} not found in {args.project_directory}')
        sys.exit(1)

    if args.sample not in sheet:
        print(f'Error: {args.sample} not found in {args.group} sample sheet')
        sys.exit(1)

    description = sheet[args.sample]
    platform = description.get('platform')

    if platform == 'pacbio':
        exp_class = experiment.PacbioExperiment
    elif platform == 'illumina':
        exp_class = experiment.IlluminaExperiment
    else:
        raise ValueError(description)

    exp = exp_class(args.project_directory, args.group, args.sample, description, args.progress)

    for stage_name in args.stages.split(','):
        exp.process(stage_name)

def make_tables(args):
    """Generate result tables and a zip archive for each group.

    For every selected group this writes, under
    ``<project_directory>/results``, a high-level html table, a detailed html
    table, a csv of counts, a csv of performance metrics, and a zip archive
    bundling those files together with per-experiment outcome files. When no
    single group is requested, a combined ``all_groups.csv`` is written as
    well.

    Args:
        args: parsed command-line arguments; reads ``project_directory`` (a
            ``Path``) and ``group`` (a group name, or a falsy value to
            process every group).
    """
    from knock_knock import experiment, table

    results_dir = args.project_directory / 'results'

    if args.group:
        groups = [args.group]
    else:
        groups = experiment.get_all_groups(args.project_directory)

        csv_fn = (results_dir / 'all_groups').with_suffix('.csv')
        df = table.load_counts(args.project_directory, exclude_empty=False).T
        df.to_csv(csv_fn)

    for group in groups:
        print(group)

        conditions = {'group': group}

        fns_to_zip = []
        exps_missing_files = set()

        def add_fn(exp, fn):
            """Queue fn (or the children of a directory) for zipping; record exp if fn is missing."""
            if not fn.exists():
                exps_missing_files.add((exp.group, exp.name))
            elif fn.is_dir():
                # Only the directory's immediate children are archived.
                fns_to_zip.extend(fn.iterdir())
            else:
                fns_to_zip.append(fn)

        # NOTE(review): with_suffix() replaces any existing suffix, so a group
        # name containing '.' is truncated at its last dot in the file names
        # below - confirm group names never contain dots.
        print('Generating high-level html table...')
        html_fn = (results_dir / group).with_suffix('.html')
        table.generate_html(args.project_directory, html_fn, conditions, show_details=False)
        fns_to_zip.append(html_fn)

        print('Generating detailed html table...')
        html_fn = (results_dir / f'{group}_with_details').with_suffix('.html')
        table.generate_html(args.project_directory, html_fn, conditions, show_details=True)
        fns_to_zip.append(html_fn)

        print('Generating csv table...')
        csv_fn = (results_dir / group).with_suffix('.csv')
        df = table.load_counts(args.project_directory, conditions, exclude_empty=False).T
        df.to_csv(csv_fn)
        fns_to_zip.append(csv_fn)

        print('Generating performance metrics...')
        pms_fn = (results_dir / f'{group}_performance_metrics').with_suffix('.csv')
        pms = table.calculate_performance_metrics(args.project_directory, conditions)
        pms.to_csv(pms_fn)
        fns_to_zip.append(pms_fn)

        exps = experiment.get_all_experiments(args.project_directory, conditions)

        for exp in exps:
            add_fn(exp, exp.fns['outcome_browser'])

            for outcome in exp.outcomes:
                outcome_fns = exp.outcome_fns(outcome)
                add_fn(exp, outcome_fns['diagrams_html'])
                add_fn(exp, outcome_fns['first_example'])
                add_fn(exp, outcome_fns['length_ranges_dir'])

            # Each outcome is a 2-tuple whose first element is the category;
            # category-level files are collected once per distinct category.
            categories = {category for category, _ in exp.outcomes}
            for category in categories:
                outcome_fns = exp.outcome_fns(category)
                add_fn(exp, outcome_fns['diagrams_html'])
                add_fn(exp, outcome_fns['first_example'])

        if exps_missing_files:
            print(f'Warning: {len(exps_missing_files)} experiment(s) are missing output files:')
            # Distinct loop names so the outer `group`, which names the zip
            # file below, is never rebound here.
            for missing_group, missing_name in sorted(exps_missing_files):
                print(f'\t{missing_group} {missing_name}')

        zip_fn = (results_dir / group).with_suffix('.zip')
        archive_base = Path(group)
        with zipfile.ZipFile(zip_fn, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_fh:
            description = 'Zipping table files'
            for fn in tqdm.tqdm(fns_to_zip, desc=description):
                # Archive members live under a top-level directory named
                # after the group, mirroring their layout in results_dir.
                arcname = archive_base / fn.relative_to(results_dir)
                zip_fh.write(fn, arcname=arcname)

def build_targets(args):
    """Entry point for the ``build_targets`` subcommand.

    Delegates to ``knock_knock.build_targets.build_target_infos_from_csv``,
    passing ``args.project_directory``.
    """
    # Imported lazily and under an alias so the module does not shadow this
    # function's own name inside its body.
    from knock_knock import build_targets as targets_module

    project_directory = args.project_directory
    targets_module.build_target_infos_from_csv(project_directory)

def design_primers(args):
    """Entry point for the ``design_primers`` subcommand.

    Delegates to
    ``knock_knock.build_targets.design_amplicon_primers_from_csv``, passing
    ``args.project_directory`` and ``args.genome``.
    """
    from knock_knock import build_targets as targets_module

    targets_module.design_amplicon_primers_from_csv(
        args.project_directory,
        args.genome,
    )

def build_indices(args):
    """Entry point for the ``build_indices`` subcommand.

    Delegates to
    ``knock_knock.build_targets.download_genome_and_build_indices``, passing
    ``args.project_directory``, ``args.genome_name`` and ``args.num_threads``.
    """
    from knock_knock import build_targets as targets_module

    targets_module.download_genome_and_build_indices(
        args.project_directory,
        args.genome_name,
        args.num_threads,
    )

if __name__ == '__main__':
    # Command-line interface: one subparser per subcommand, each of which
    # stores its handler in `func` so dispatch is a single call at the end.

    # Every subcommand takes the project directory as its first positional
    # argument, so the shared help text is defined once.
    project_directory_help = 'the base directory to store input data, reference annotations, and analysis output for a project'
    default_stages = 'align,categorize,visualize'

    parser = argparse.ArgumentParser(prog='knock-knock')

    parser.add_argument('--version', action='version', version=knock_knock.__version__)

    subparsers = parser.add_subparsers(dest='subcommand', title='subcommands')
    # Set after creation so that omitting the subcommand is a usage error.
    subparsers.required = True

    parser_process = subparsers.add_parser('process', help='process a single sample')
    parser_process.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_process.add_argument('group', help='group name')
    parser_process.add_argument('sample', help='sample name')
    # With --progress, args.progress is the tqdm callable; otherwise None.
    parser_process.add_argument('--progress', const=tqdm.tqdm, action='store_const', help='show progress bars')
    parser_process.add_argument('--stages', default=default_stages)
    parser_process.set_defaults(func=process)

    parser_parallel = subparsers.add_parser('parallel', help='process multiple samples in parallel')
    parser_parallel.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_parallel.add_argument('max_procs', type=int, help='maximum number of samples to process at once')
    parser_parallel.add_argument('--group', help='if specified, the single group name to process; if not specified, all groups will be processed')
    parser_parallel.add_argument('--conditions', type=yaml.safe_load, default={}, help='if specified, conditions that samples must satisfy to be processed, given as yaml; if not specified, all samples will be processed')
    parser_parallel.add_argument('--stages', default=default_stages)
    parser_parallel.set_defaults(func=parallel)

    parser_table = subparsers.add_parser('table', help='generate tables of outcome frequencies')
    parser_table.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_table.add_argument('--group', help='if specified, the single group name to generate tables for; if not specified, all groups will be generated')
    parser_table.set_defaults(func=make_tables)

    parser_targets = subparsers.add_parser('build_targets', help='build annotations of target loci')
    parser_targets.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_targets.set_defaults(func=build_targets)

    parser_primers = subparsers.add_parser('design_primers', help='design amplicon primers for sgRNAs')
    parser_primers.add_argument('project_directory', type=Path, help=project_directory_help)
    # design_primers() reads args.genome, so the parser must define it;
    # without this argument the subcommand always fails with AttributeError.
    parser_primers.add_argument('genome', help='genome to use when designing primers')
    parser_primers.set_defaults(func=design_primers)

    parser_indices = subparsers.add_parser('build_indices', help='download a reference genome and build alignment indices')
    parser_indices.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_indices.add_argument('genome_name', help='name of genome to download')
    parser_indices.add_argument('--num_threads', type=int, default=8, help='number of threads to use for index building')
    parser_indices.set_defaults(func=build_indices)

    args = parser.parse_args()
    args.func(args)
