#!/usr/bin/env python3.6

import argparse
import subprocess
import sys
import zipfile
from pathlib import Path

import yaml
import tqdm

import knock_knock


# Reference to cite when using knock-knock; printed by the `whos_there`
# subcommand via print_citation().
citation = '''
Hera Canaj, Jeffrey A. Hussmann, Han Li, Kyle A. Beckman, Leeanne Goodrich,
Nathan H. Cho, Yucheng J. Li, Daniel Santos, Aaron McGeever, Veronica Pessino,
Cindy Huang, Li Gan, Barbara Panning, Bo Huang, Jonathan S. Weissman and
Manuel D. Leonetti. Deep profiling reveals the complexity of integration
outcomes in CRISPR knock-in experiments. bioRxiv (2019).
'''

def check_blastn():
    """Exit the program unless blastn version 2.7.1 can be run.

    Runs ``blastn -version`` and looks for ``2.7.1`` in its output. If the
    executable cannot be started, exits with a non-zero status, or reports a
    different version, an error is printed to stderr and the program exits
    with status 1. A non-zero status lets callers (e.g. GNU parallel driving
    ``knock-knock process``) see the failure.
    """
    try:
        output = subprocess.check_output(['blastn', '-version'])
    except (OSError, subprocess.CalledProcessError):
        # OSError: binary missing or not executable.
        # CalledProcessError: blastn ran but exited with a non-zero status.
        output = b''

    if b'2.7.1' not in output:
        print('blastn 2.7.1 is required and couldn\'t be found', file=sys.stderr)
        sys.exit(1)

def check_parallel():
    """Exit the program unless GNU parallel can be run.

    Runs ``parallel --version`` and requires its output to start with
    ``GNU parallel``. If the executable cannot be started, exits with a
    non-zero status, or is some other program named ``parallel``, an error is
    printed to stderr and the program exits with status 1.
    """
    try:
        output = subprocess.check_output(['parallel', '--version'])
    except (OSError, subprocess.CalledProcessError):
        # OSError: binary missing or not executable. A non-GNU program named
        # `parallel` may reject --version with a non-zero exit status.
        output = b''

    if not output.startswith(b'GNU parallel'):
        print('GNU parallel is required and couldn\'t be found', file=sys.stderr)
        sys.exit(1)

def parallel(args):
    """Process every experiment matching the given conditions via GNU parallel.

    Builds one ``parallel`` invocation that runs
    ``knock-knock process <project_directory> <group> <name> --stages <stages>``
    once per selected experiment, at most ``args.max_procs`` at a time.

    Args:
        args: parsed command-line namespace with ``project_directory``,
            ``max_procs``, ``group``, ``conditions`` and ``stages``.

    Exits with status 1 if ``--conditions`` is not a YAML mapping, and with
    status 0 (after printing the conditions) if no experiment matches.
    """
    from knock_knock import experiment

    check_parallel()

    # `--conditions ''` parses to None under yaml.safe_load; treat it as empty.
    conditions = args.conditions if args.conditions is not None else {}
    if not isinstance(conditions, dict):
        print(f'Error: --conditions must be a YAML mapping, got {conditions!r}', file=sys.stderr)
        sys.exit(1)

    # Copy so the parsed dict (argparse's shared default when the flag is
    # omitted) is never mutated.
    conditions = dict(conditions)
    if args.group:
        conditions['group'] = args.group

    exps = experiment.get_all_experiments(args.project_directory, conditions)

    if not exps:
        print('No experiments satisfy conditions:')
        print(conditions)
        sys.exit(0)

    parallel_command = [
        'parallel',
        '--will-cite',
        # Each job consumes 4 arguments: group, name, '--stages', stages.
        '-n', '4',
        '--bar',
        '--max-procs', str(args.max_procs),
        'knock-knock',
        'process',
        str(args.project_directory),
        ':::',
    ]

    arg_tuples = sorted((exp.group, exp.name, '--stages', args.stages) for exp in exps)
    for arg_tuple in arg_tuples:
        parallel_command.extend(arg_tuple)

    subprocess.check_call(parallel_command)

def process(args):
    """Run the requested processing stages on a single sample.

    Loads ``args.group``'s sample sheet and picks the experiment class from the
    sample description's ``platform`` entry ('pacbio' or 'illumina'). It then
    calls ``exp.process(stage)`` for each stage in the comma-separated
    ``args.stages`` string, in order.

    If the group's sample sheet or the sample itself cannot be found, prints
    an error to stderr and exits with status 1.

    Raises:
        ValueError: if the sample's platform is not 'pacbio' or 'illumina'.
    """
    from knock_knock import experiment

    check_blastn()

    sample_sheet = experiment.load_sample_sheet(args.project_directory, args.group)

    if sample_sheet is None:
        print(f'Error: {args.group} not found in {args.project_directory}', file=sys.stderr)
        sys.exit(1)

    if args.sample not in sample_sheet:
        print(f'Error: {args.sample} not found in {args.group} sample sheet', file=sys.stderr)
        sys.exit(1)

    description = sample_sheet[args.sample]
    platform = description.get('platform')

    if platform == 'pacbio':
        exp_class = experiment.PacbioExperiment
    elif platform == 'illumina':
        exp_class = experiment.IlluminaExperiment
    else:
        raise ValueError(
            f'unrecognized platform {platform!r} for sample {args.sample} '
            f'in group {args.group}: {description}'
        )

    exp = exp_class(args.project_directory, args.group, args.sample, description, args.progress)

    # Tolerate whitespace and stray commas, e.g. 'align, categorize,'.
    stages = [stage.strip() for stage in args.stages.split(',') if stage.strip()]
    for stage in stages:
        exp.process(stage)

def _collect_experiment_files(exp, fns_to_zip, exps_missing_files):
    """Queue exp's output files for zipping; record exp if any are missing.

    Existing files, or the immediate children of existing directories, are
    appended to ``fns_to_zip``. If any expected path does not exist,
    ``(exp.group, exp.name)`` is added to ``exps_missing_files``.
    """
    def add_fn(fn):
        if not fn.exists():
            exps_missing_files.add((exp.group, exp.name))
        elif fn.is_dir():
            # Only the directory's immediate children are queued (no recursion).
            fns_to_zip.extend(fn.iterdir())
        else:
            fns_to_zip.append(fn)

    add_fn(exp.fns['outcome_browser'])
    add_fn(exp.fns['lengths_figure'])

    for outcome in exp.outcomes:
        outcome_fns = exp.outcome_fns(outcome)
        add_fn(outcome_fns['diagrams_html'])
        add_fn(outcome_fns['first_example'])
        add_fn(outcome_fns['length_ranges_dir'])

    # Also queue the per-category pages. Each outcome is a pair whose first
    # element is its category. dict.fromkeys dedupes while keeping first-seen
    # order; iterating a set of strings may vary between runs.
    categories = dict.fromkeys(category for category, _ in exp.outcomes)
    for category in categories:
        outcome_fns = exp.outcome_fns(category)
        add_fn(outcome_fns['diagrams_html'])
        add_fn(outcome_fns['first_example'])


def make_tables(args):
    """Write per-group summary tables plus a zip archive of tables and figures.

    Processes ``args.group``, or every group in the project when no group is
    given. For each group it writes, under ``<project_directory>/results``:

    - ``<group>.html`` and ``<group>_with_details.html`` outcome tables,
    - ``<group>.csv`` outcome counts,
    - ``<group>_performance_metrics.csv``,
    - ``<group>.zip``, containing the files above plus each experiment's
      outcome browser, lengths figure and per-outcome/per-category files.

    When no group is given, a combined ``all_groups.csv`` is written as well.
    Experiments with missing output files are reported but do not abort the
    run.
    """
    from knock_knock import experiment, table

    results_dir = args.project_directory / 'results'
    results_dir.mkdir(exist_ok=True)

    if args.group:
        groups = [args.group]
    else:
        groups = experiment.get_all_groups(args.project_directory)

        csv_fn = results_dir / 'all_groups.csv'
        df = table.load_counts(args.project_directory, exclude_empty=False).T
        df.to_csv(csv_fn)

    for group in groups:
        print(group)

        conditions = {'group': group}

        fns_to_zip = []

        # Output names use f-strings rather than Path.with_suffix. with_suffix
        # replaces any '.xyz' tail already in the group name (e.g.
        # 'run.2' -> 'run.html'), which would misname the outputs and make the
        # two html tables overwrite each other.

        print('Generating high-level html table...')
        html_fn = results_dir / f'{group}.html'
        table.generate_html(args.project_directory, html_fn, conditions, show_details=False)
        fns_to_zip.append(html_fn)

        print('Generating detailed html table...')
        html_fn = results_dir / f'{group}_with_details.html'
        table.generate_html(args.project_directory, html_fn, conditions, show_details=True)
        fns_to_zip.append(html_fn)

        print('Generating csv table...')
        csv_fn = results_dir / f'{group}.csv'
        df = table.load_counts(args.project_directory, conditions, exclude_empty=False).T
        df.to_csv(csv_fn)
        fns_to_zip.append(csv_fn)

        print('Generating performance metrics...')
        pms_fn = results_dir / f'{group}_performance_metrics.csv'
        pms = table.calculate_performance_metrics(args.project_directory, conditions)
        pms.to_csv(pms_fn)
        fns_to_zip.append(pms_fn)

        exps_missing_files = set()
        for exp in experiment.get_all_experiments(args.project_directory, conditions):
            _collect_experiment_files(exp, fns_to_zip, exps_missing_files)

        if exps_missing_files:
            print(f'Warning: {len(exps_missing_files)} experiment(s) are missing output files:')
            # Distinct loop names so the outer `group` is not rebound before
            # it is used for the zip filename below.
            for exp_group, exp_name in sorted(exps_missing_files):
                print(f'\t{exp_group} {exp_name}')

        zip_fn = results_dir / f'{group}.zip'
        archive_base = Path(group)
        with zipfile.ZipFile(zip_fn, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_fh:
            for fn in tqdm.tqdm(fns_to_zip, desc='Zipping table files'):
                # NOTE(review): relative_to raises ValueError if an experiment's
                # output lives outside results_dir; assumed not to happen, confirm.
                arcname = archive_base / fn.relative_to(results_dir)
                zip_fh.write(fn, arcname=arcname)

def build_targets(args):
    """Handle the `build_targets` subcommand.

    Delegates to ``knock_knock.build_targets.build_target_infos_from_csv``
    with the project directory taken from ``args``.
    """
    from knock_knock import build_targets as targets_module

    project_directory = args.project_directory
    targets_module.build_target_infos_from_csv(project_directory)

def design_primers(args):
    """Handle the `design_primers` subcommand.

    Delegates to ``knock_knock.build_targets.design_amplicon_primers_from_csv``
    with the project directory and genome name taken from ``args``.
    """
    from knock_knock import build_targets as targets_module

    project_directory, genome = args.project_directory, args.genome
    targets_module.design_amplicon_primers_from_csv(project_directory, genome)

def build_indices(args):
    """Handle the `build_indices` subcommand.

    Delegates to ``knock_knock.build_targets.download_genome_and_build_indices``
    with the project directory, genome name and thread count from ``args``.
    """
    from knock_knock import build_targets as targets_module

    call_args = (args.project_directory, args.genome_name, args.num_threads)
    targets_module.download_genome_and_build_indices(*call_args)

def install_example_data(args):
    """Copy the packaged example ``data`` and ``targets`` directories into a project.

    The sources are ``example_data/data`` and ``example_data/targets`` inside
    the installed knock_knock package. If either destination already exists,
    nothing is copied; an error is printed to stderr and the program exits
    with status 1.
    """
    import os
    import shutil
    import knock_knock

    package_dir = Path(os.path.realpath(knock_knock.__file__)).parent / 'example_data'
    subdirs_to_copy = ['data', 'targets']

    # Check every destination before copying anything, so a conflict on a
    # later subdirectory cannot leave a half-installed project behind.
    for subdir in subdirs_to_copy:
        dest = args.project_directory / subdir
        if dest.exists():
            print(f'Can\'t install to {args.project_directory}, {dest} already exists', file=sys.stderr)
            sys.exit(1)

    for subdir in subdirs_to_copy:
        shutil.copytree(str(package_dir / subdir), str(args.project_directory / subdir))

    print(f'Example data installed in {args.project_directory}')

def print_citation(args):
    """Handle the `whos_there` subcommand by printing the citation text.

    ``args`` is unused. It is accepted only because every subcommand handler
    is invoked as ``args.func(args)``.
    """
    del args  # unused
    print(citation)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='knock-knock')

    parser.add_argument('--version', action='version', version=knock_knock.__version__)

    subparsers = parser.add_subparsers(dest='subcommand', title='subcommands')
    # argparse subcommands are optional by default; without one there is no
    # `func` to dispatch to.
    subparsers.required = True

    # Help strings shared by several subcommands.
    project_directory_help = 'the base directory to store input data, reference annotations, and analysis output for a project'
    stages_help = 'comma-separated list of processing stages to run'

    parser_process = subparsers.add_parser('process', help='process a single sample')
    parser_process.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_process.add_argument('group', help='group name')
    parser_process.add_argument('sample', help='sample name')
    parser_process.add_argument('--progress', const=tqdm.tqdm, action='store_const', help='show progress bars')
    parser_process.add_argument('--stages', default='align,categorize,visualize', help=stages_help)
    parser_process.set_defaults(func=process)

    parser_parallel = subparsers.add_parser('parallel', help='process multiple samples in parallel')
    parser_parallel.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_parallel.add_argument('max_procs', type=int, help='maximum number of samples to process at once')
    parser_parallel.add_argument('--group', help='if specified, the single group name to process; if not specified, all groups will be processed')
    parser_parallel.add_argument('--conditions', type=yaml.safe_load, default={}, help='if specified, conditions that samples must satisfy to be processed, given as yaml; if not specified, all samples will be processed')
    parser_parallel.add_argument('--stages', default='align,categorize,visualize', help=stages_help)
    parser_parallel.set_defaults(func=parallel)

    parser_table = subparsers.add_parser('table', help='generate tables of outcome frequencies')
    parser_table.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_table.add_argument('--group', help='if specified, the single group name to generate tables for; if not specified, all groups will be generated')
    parser_table.set_defaults(func=make_tables)

    parser_targets = subparsers.add_parser('build_targets', help='build annotations of target loci')
    parser_targets.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_targets.set_defaults(func=build_targets)

    parser_primers = subparsers.add_parser('design_primers', help='design amplicon primers for sgRNAs')
    parser_primers.add_argument('project_directory', type=Path, help=project_directory_help)
    # design_primers() reads args.genome, so this subcommand must define it.
    parser_primers.add_argument('genome', help='name of genome to use when designing primers')
    parser_primers.set_defaults(func=design_primers)

    parser_indices = subparsers.add_parser('build_indices', help='download a reference genome and build alignment indices')
    parser_indices.add_argument('project_directory', type=Path, help=project_directory_help)
    parser_indices.add_argument('genome_name', help='name of genome to download')
    parser_indices.add_argument('--num_threads', type=int, default=8, help='number of threads to use for index building')
    parser_indices.set_defaults(func=build_indices)

    parser_install_data = subparsers.add_parser('install_example_data', help='install example data into user-specified project directory')
    parser_install_data.add_argument('project_directory', type=Path, help='directory to install example data')
    parser_install_data.set_defaults(func=install_example_data)

    parser_citation = subparsers.add_parser('whos_there', help='print citation information')
    parser_citation.set_defaults(func=print_citation)

    args = parser.parse_args()
    args.func(args)
