#! /usr/bin/env python3

import os
import argparse

import pandas as pd

from fabric.util import ALL_NTS_SET, log
from fabric.common_args import get_parser_file_type, get_parser_directory_type
from fabric.gene_score_dist import GeneScoreModel
from fabric.analysis import analyze_variant_effects

def validate_and_transform_effect_scores(possible_effects, lower_effect_scores_are_more_damaging):
    """
    Validate the 'effect_score' column of the possible variant effects and, if needed, flip it in place.

    All scores must lie within [0, 1]. When lower_effect_scores_are_more_damaging is truthy, synonymous rows must
    score exactly 1 and nonsense rows exactly 0, and the table is left untouched. Otherwise synonymous rows must
    score exactly 0 and nonsense rows exactly 1, and the 'effect_score' column is replaced by 1 - effect_score.

    The given dataframe is modified in place; nothing is returned. An AssertionError is raised on any violation.
    """

    scores = possible_effects['effect_score']
    assert scores.between(0, 1).all(), 'Effect scores must be between 0 to 1.'

    effect_types = possible_effects['effect_type']
    synonymous_scores = scores[effect_types == 'synonymous']
    nonsense_scores = scores[effect_types == 'nonsense']

    if lower_effect_scores_are_more_damaging:
        # The scores already follow the convention used downstream, so they only need to be checked.
        flag_prefix = 'When --lower-effect-scores-are-more-damaging is provided, '
        assert (synonymous_scores == 1).all(), flag_prefix + 'all synonymous variants must have an effect score of 1.'
        assert (nonsense_scores == 0).all(), flag_prefix + 'all nonsense variants must have an effect score of 0.'
        return

    assert (synonymous_scores == 0).all(), 'All synonymous variants must have an effect score of 0.'
    assert (nonsense_scores == 1).all(), 'All nonsense variants must have an effect score of 1.'

    # Flip the scale (x -> 1 - x) so that lower scores become the more damaging ones.
    possible_effects['effect_score'] = 1 - scores
        
def calculate_gene_cds_lens(possible_effects):
    """
    Derive the CDS length of every gene from the table of possible variant effects.

    Checks that every (gene_index, pos) pair appears in exactly 3 rows, raising an AssertionError otherwise, and
    returns a series indexed by gene_index holding the number of distinct 'pos' values listed for that gene.
    """

    effects_per_position = possible_effects.groupby(['gene_index', 'pos']).size()
    assert (effects_per_position == 3).all(), 'Every gene is expected to have exactly 3 possible effects listed at every position ' + \
            '(corresponding to each reference nucleotide substituting into each of the other three possible nucleotides).'

    positions_by_gene = possible_effects.groupby('gene_index')['pos']
    return positions_by_gene.nunique()
    
def parse_chrom(chrom):
    """
    Return the chromosome name without a leading 'chr' prefix (matched case-insensitively).

    Names that don't start with that prefix are returned unchanged.
    """
    has_chr_prefix = chrom.lower().startswith('chr')
    return chrom[3:] if has_chr_prefix else chrom

if __name__ == '__main__':

    # Command-line entry point. The script parses and cross-validates the arguments, loads the three input tables
    # (input variants, possible variant effects and genes), maps the input variants onto the possible effects, builds
    # a background model per relevant gene, and writes the analysis results (per project and/or combined) as CSV.

    # Parse and validate the arguments.

    parser = argparse.ArgumentParser(description = 'Run FABRIC on a list of input variants, calculating the alteration bias in each gene independently.')
    parser.add_argument('--input-variants-file', dest = 'input_variants_file', metavar = '/path/to/input_variants.csv', type = get_parser_file_type(parser, must_exist = True), \
            required = True, help = 'Path to the list of input variants to analyze (a CSV file).')
    parser.add_argument('--possible-variant-effects-file', dest = 'possible_variant_effects_file', metavar = '/path/to/possible_variant_effects.csv', \
            type = get_parser_file_type(parser, must_exist = True), required = True, help = 'Path to the list of all possible variant effects, for the input variants ' + \
            'to be analyzed against. This file is expected to list all possible CDS-affecting single-nucleotide variants. It is expected to have the following columns: ' + \
            'gene_index, chrom, pos, ref, alt, effect_type, effect_score. See the documentation for instructions on how to obtain or generate this file. Make sure that ' + \
            'the variants listed in this file are of the same version of the reference genome as the input variants.')
    parser.add_argument('--genes-file', dest = 'genes_file', metavar = '/path/to/genes.csv', type = get_parser_file_type(parser, must_exist = True), \
            required = True, help = 'A CSV file with meta-data about all the genes referenced by --possible-variant-effects-file (the gene_index in this file is expected to ' + \
            'match the first column in --genes-file). The columns in this file will be included in the output CSV(s).')
    parser.add_argument('--output-file', dest = 'output_file', metavar = '/path/to/output.csv', \
            type = get_parser_file_type(parser), required = False, help = 'The file to save the analysis output into. This argument should be given only if there is a ' + \
            'single output file. If --analyze-by-project-column is given, use --output-dir instead.')
    parser.add_argument('--output-dir', dest = 'output_dir', metavar = '/path/to/output_dir', type = get_parser_directory_type(parser, create_if_not_exists = True), \
            required = False, help = 'The directory where the analysis output files will be saved. Will create the directory if it doesn\'t exist.')
    parser.add_argument('--override', dest = 'override', action = 'store_true', help = 'Allow FABRIC to override output files.')
    parser.add_argument('--analyze-by-project-column', dest = 'analyze_by_project_column', metavar = 'project', type = str, required = False, help = 'A column name in the ' + \
            'file of input variants to be regarded as a project. If provided, FABRIC will also analyze the input variants for each project independently, creating an output ' + \
            'file for each project.')
    parser.add_argument('--skip-combined-analysis', dest = 'skip_combined_analysis', action = 'store_true', help = 'By default, even when --analyze-by-project-column is ' + \
            'provided, FABRIC will also run a combined analysis of all the variants together. If this flag is provided, FABRIC will skip that combined analysis.')
    parser.add_argument('--analyze-diff', dest = 'analyze_diff', action = 'store_true', help = 'Also analyze the difference in alteration bias among ' + \
            'projects, resulting additional output columns with the results of this analysis.')
    parser.add_argument('--only-gene-indices', dest = 'only_gene_indices', metavar = '<index>', nargs = '+', type = int, required = False, \
            help = 'Limit the analysis only to the provided genes (identified by their index). If not provided, will analyze all relevant genes (for which there is at ' + \
            'least one input variant with a possible variant effect).')
    parser.add_argument('--input-variants-tab-delimiter', dest = 'input_variants_tab_delimiter', action = 'store_true',  help = 'By default, --input-variants-file is assumed ' + \
            'to be comma-delimited. Use this flag if it is tab-delimited.')
    parser.add_argument('--chrom-column', dest = 'chrom_column', metavar = 'chrom', type = str, default = 'chrom', help = 'The name of the column in the file of input variants ' + \
            'specifying the chromosome of each variant (default is "chrom").')
    parser.add_argument('--pos-column', dest = 'pos_column', metavar = 'pos', type = str, default = 'pos', help = 'The name of the column in the file of input variants ' + \
            'specifying the position of each variant (default is "pos").')
    parser.add_argument('--ref-column', dest = 'ref_column', metavar = 'ref', type = str, default = 'ref', help = 'The name of the column in the file of input variants ' + \
            'specifying the reference sequence of each variant (default is "ref").')
    parser.add_argument('--alt-column', dest = 'alt_column', metavar = 'alt', type = str, default = 'alt', help = 'The name of the column in the file of input variants ' + \
            'specifying the alternative sequence of each variant (default is "alt").')
    parser.add_argument('--lower-effect-scores-are-more-damaging', dest = 'lower_effect_scores_are_more_damaging', action = 'store_true',  help = 'By default, FABRIC assumes ' + \
            'that the effect scores provided in the list of possible variant effects are such that synonymous (non-damaging) variants have an effect score of 0, nonsense ' + \
            '(damaging) variants have an effect score of 1, and missense variants have an effect score between 0 to 1, where higher effect scores correspond to higher ' + \
            'probability of damage. However, the algorithm and outputs of FABRIC actually work the opposite way: higher scores correspond to less predicted damage. Therefore, ' + \
            'before analyzing the variants, FABRIC applies the transformation x -> 1 - x on the provided effect scores. If the provided effect scores are already such that ' + \
            'lower effect scores indicate more damaging variants, then you need to provide this flag.')
    args = parser.parse_args()

    # Exactly one output destination must be chosen (a single file or a directory).
    if int(args.output_file is None) + int(args.output_dir is None) != 1:
        parser.error('Exactly one of --output-file or --output-dir should be given.')

    # Unless --override is given, refuse to write over an existing file or into a non-empty directory.
    if not args.override:

        if args.output_file is not None and os.path.exists(args.output_file):
            parser.error('Output file already exists: %s. Consider using the --override flag.' % args.output_file)

        if args.output_dir is not None and len(os.listdir(args.output_dir)) > 0:
            parser.error('Output directory is not empty: %s. Consider using the --override flag.' % args.output_dir)

    # A per-project analysis produces one file per project, so it needs a directory.
    if args.analyze_by_project_column is not None and args.output_dir is None:
        parser.error('When --analyze-by-project-column is provided, --output-dir should be given as well (rather than --output-file).')

    if args.skip_combined_analysis and args.analyze_by_project_column is None:
        parser.error('Cannot skip the combined analysis if --analyze-by-project-column is not provided. There is nothing to analyze. Either provide a project column, or ' + \
                'remove the --skip-combined-analysis flag.')

    if args.analyze_diff and args.analyze_by_project_column is None:
        parser.error('--analyze-diff requires a project columns to be set by --analyze-by-project-column.')


    # Load and process the data.

    # Only the columns needed for the analysis are read from the CSV files.
    input_variants_id_columns = [args.chrom_column, args.pos_column, args.ref_column, args.alt_column]
    input_variants_relevant_columns = input_variants_id_columns if args.analyze_by_project_column is None else input_variants_id_columns + [args.analyze_by_project_column]
    possible_effects_id_columns = ['chrom', 'pos', 'ref', 'alt']
    possible_effects_relevant_columns = ['gene_index'] + possible_effects_id_columns + ['effect_type', 'effect_score']

    log('Loading input variants...')
    input_variants = pd.read_csv(args.input_variants_file, delimiter = '\t' if args.input_variants_tab_delimiter else ',', usecols = input_variants_relevant_columns,
            dtype = {args.chrom_column: str}, engine = 'python')
    # Strip any leading "chr" prefix from the chromosome names of the input variants.
    input_variants[args.chrom_column] = input_variants[args.chrom_column].apply(parse_chrom)
    log('Loaded %d input variants.' % len(input_variants))

    log('Loading possible variant effects...')
    possible_effects = pd.read_csv(args.possible_variant_effects_file, usecols = possible_effects_relevant_columns, dtype = {'chrom': str})
    log('Loaded %d possible variant effects.' % len(possible_effects))

    log('Loading genes...')
    genes = pd.read_csv(args.genes_file, index_col = 0)
    log('Loaded %d genes.' % len(genes))

    assert set(possible_effects['ref'].unique()) <= ALL_NTS_SET and set(possible_effects['alt'].unique()) <= ALL_NTS_SET, 'All possible variant effects are expected to be ' + \
            'single-nucleotide substitutions.'
    assert set(possible_effects['effect_type'].unique()) <= {'synonymous', 'missense', 'nonsense'}, 'Effect types must be either of "synonymous", "missense" and "nonsense".'
    # Validates the effect scores and, unless --lower-effect-scores-are-more-damaging was given, flips them in place.
    validate_and_transform_effect_scores(possible_effects, args.lower_effect_scores_are_more_damaging)
    assert set(possible_effects['gene_index'].unique()) <= set(genes.index), 'Gene indices must be included in the (first) index column of the genes file.'

    if 'cds_len' in genes.columns:
        log('Warning: overriding the "cds_len" column in the genes table with those derived from the possible variant effects.')

    log('Calculating the lengths of the genes based on the list of possible variants...')
    # Genes without any listed possible effect end up with a missing cds_len after the reindex.
    genes['cds_len'] = calculate_gene_cds_lens(possible_effects).reindex(genes.index)

    log('Mapping the input variants (identified by the columns %s) against the possible variant effects (identified by the columns %s)...' % \
            (input_variants_id_columns, possible_effects_id_columns))
    input_variant_effects = input_variants.merge(possible_effects, left_on = input_variants_id_columns, right_on = possible_effects_id_columns, validate = 'm:m')
    log('Mapped the %d input variants into %d relevant variant effects. Note that each input variant could potentially affect multiple genes and result in multiple effects.' % \
            (len(input_variants), len(input_variant_effects)))
    assert len(input_variant_effects) > 0, "Ended up without any mapped variant effects. Something is probably wrong with the inputs."
    del input_variants

    if args.only_gene_indices is not None:
        # De-duplicate the requested indices so that repeated values don't skew the counts reported below.
        only_gene_indices = set(args.only_gene_indices)
        input_variant_effects = input_variant_effects[input_variant_effects['gene_index'].isin(only_gene_indices)]
        log('Restricting the analysis to the %d genes provided by --only-gene-indices, %d input variant effects are still relevant.' % \
                (len(only_gene_indices), len(input_variant_effects)))

    relevant_gene_indices = set(input_variant_effects['gene_index'].unique())

    if args.only_gene_indices is not None and len(relevant_gene_indices) < len(only_gene_indices):
        log('Only %d of the provided gene indices exist among the input variant effects.' % len(relevant_gene_indices))

    log('Will analyze %d relevant genes (that are affected by at least one input variant).' % len(relevant_gene_indices))
    assert len(relevant_gene_indices) > 0, "There are no genes to analyze."
    # Background models are only needed for genes that are hit by at least one input variant.
    possible_effects = possible_effects[possible_effects['gene_index'].isin(relevant_gene_indices)]

    log('Constructing background models for the %d relevant genes based on the possible variant effects...' % len(relevant_gene_indices))
    gene_score_models = {gene_index: GeneScoreModel(gene_possible_variants) for gene_index, gene_possible_variants in possible_effects.groupby('gene_index')}
    del possible_effects


    # Run the analysis.

    if args.analyze_by_project_column is not None:
        for project, project_input_variant_effects in input_variant_effects.groupby(args.analyze_by_project_column):
            log('Analyzing project "%s" (%d input variant effects)...' % (project, len(project_input_variant_effects)))
            # With --analyze-diff, the variant effects of all the other projects are passed along as well.
            other_variant_effects = input_variant_effects[input_variant_effects[args.analyze_by_project_column] != project] if args.analyze_diff else None
            project_results = analyze_variant_effects(project_input_variant_effects, genes, gene_score_models, other_variant_effects)
            log('%s: %d significant genes.' % (project, project_results['overall_fdr_significance'].sum()))
            project_results.to_csv(os.path.join(args.output_dir, '%s.csv' % project))

    if not args.skip_combined_analysis:
        log('Running the combined analysis on all %d input variant effects...' % len(input_variant_effects))
        combined_results = analyze_variant_effects(input_variant_effects, genes, gene_score_models)
        log('Combined analysis: %d significant genes.' % combined_results['overall_fdr_significance'].sum())
        # The combined results go to --output-file when given, otherwise to combined.csv inside --output-dir.
        combined_output_file_path = os.path.join(args.output_dir, 'combined.csv') if args.output_file is None else args.output_file
        combined_results.to_csv(combined_output_file_path)

    log('Done.')
