#! /usr/bin/env python3

import os
import re
import json
import argparse

import numpy as np
import pandas as pd

from pwas.statistics import GeneTest, BinaryTraitGeneScoreTest, ContinuousTraitGeneScoreTest
from pwas.shared_utils.util import log, get_recognized_files_in_dir, get_chunk_slice, get_parser_file_type, get_parser_directory_type, \
        add_parser_task_arguments, determine_parser_task_details, is_binary_series
        
def determine_covariate_cols(args):
    '''
    Collect the covariate column names requested on the command line.

    Takes the union of args.covariate_cols (from --covariate-cols) and, when
    args.covariate_cols_json_file is not None, the entries loaded from that JSON file.
    Logs how many covariates were collected, or a warning when there are none.

    Returns a sorted list of the unique covariate column names.
    '''

    requested_cols = set(args.covariate_cols)
    json_file_path = args.covariate_cols_json_file

    if json_file_path is not None:
        with open(json_file_path, 'r') as json_file:
            # Union with whatever collection the JSON file holds (expected to be a list of column names).
            requested_cols.update(json.load(json_file))

    if not requested_cols:
        log('Warning: No covariates were provided! Are you sure about that?!')
    else:
        log('%d covariates will be included.' % len(requested_cols))

    return sorted(requested_cols)
    
def get_gene_effect_score_files(gene_effect_scores_dir):
    '''
    List the per-gene effect score files found in gene_effect_scores_dir.

    A file is recognized when its name is "<gene index>.csv" (e.g. "17.csv"); the numeric part is
    parsed into an int gene index. The directory listing itself is delegated to
    get_recognized_files_in_dir, which receives the name parser defined below.
    NOTE(review): the main block consumes the result as (gene_index, file_name) pairs, and presumably
    the helper skips names for which the parser raises — confirm against get_recognized_files_in_dir.
    '''

    gene_csv_name_regex = re.compile(r'^(\d+)\.csv$')

    def extract_gene_index(csv_file_name):
        # findall() yields a single captured digit string on a match and an empty list otherwise,
        # so the one-element unpacking raises ValueError for names that don't fit the pattern.
        (gene_index_digits,) = gene_csv_name_regex.findall(csv_file_name)
        return int(gene_index_digits)

    return get_recognized_files_in_dir(gene_effect_scores_dir, extract_gene_index)

if __name__ == '__main__':

    # ---- Command-line interface ----

    parser = argparse.ArgumentParser(description = 'Run the PWAS statistical tests per gene (looking for associations between a phenotype to gene ' + \
            'effect scores, while potentially accounting for covariates).')
    parser.add_argument('--dataset-file', dest = 'dataset_file', metavar = '/path/to/dataset.csv/', type = get_parser_file_type(parser, must_exist = True), \
            required = True, help = 'Path to the dataset CSV file of the cohort (a row per sample, with eid, phenotype and covariate columns).')
    parser.add_argument('--gene-effect-scores-dir', dest = 'gene_effect_scores_dir', metavar = '/path/to/gene_effect_scores/', \
            type = get_parser_directory_type(parser), required = True, help = 'The directory with the gene effect scores (with a separate CSV file ' + \
            'per gene).')
    parser.add_argument('--per-gene-pwas-results-dir', dest = 'per_gene_pwas_results_dir', metavar = '/path/to/output_per_gene_pwas_results/', \
            type = get_parser_directory_type(parser), required = True, help = 'The directory in which the results will be written (a CSV file per ' + \
            'gene, with the relevant summary statistics).')
    parser.add_argument('--sample-id-col', dest = 'sample_id_col', metavar = '<COL_NAME>', type = str, required = True, \
            help = 'The name of the column (within the dataset CSV file) listing the identifier of each sample (which is expected to match to the ' + \
            '"sample_id" column in the gene effect score CSV files).')
    parser.add_argument('--phenotype-col', dest = 'phenotype_col', metavar = '<COL_NAME>', type = str, required = True, \
            help = 'The name of the column (within the dataset CSV file) for the phenotype to be tested. The phenotype can be either continuous or ' + \
            'binary (will be automatically detected if the values are only 0s and 1s).')
    # Covariates are column *names* (used to index the dataset below), so they must be parsed as strings.
    # With nargs='+', argparse expects them space-separated (e.g. --covariate-cols age sex).
    parser.add_argument('--covariate-cols', dest = 'covariate_cols', metavar = '<COL_NAME>', type = str, nargs = '+', default = [], \
            help = 'Columns (within the dataset CSV file) to be considered as covariates (space separated). See also --covariate-cols-json-file.')
    parser.add_argument('--covariate-cols-json-file', dest = 'covariate_cols_json_file', metavar = '/path/to/covariate_columns.json', \
            type = get_parser_file_type(parser, must_exist = True), default = None, help = 'An optional JSON file with a list of column names to ' + \
            'consider as covariates. If both --covariate-cols and --covariate-cols-json-file are provided, will take the union of the provided column names.')
    add_parser_task_arguments(parser)
    args = parser.parse_args()

    task_index, total_tasks = determine_parser_task_details(args)
    covariate_cols = determine_covariate_cols(args)

    # ---- Load and filter the cohort dataset ----

    dataset = pd.read_csv(args.dataset_file, index_col = args.sample_id_col)
    log('Loaded the dataset (%s): %d samples x %d columns.' % (args.dataset_file, len(dataset), len(dataset.columns)))

    # Fail early with a clear message rather than an opaque KeyError from dropna() / column selection below.
    required_cols = [args.phenotype_col] + covariate_cols
    missing_cols = [col for col in required_cols if col not in dataset.columns]
    if missing_cols:
        parser.error('The following columns are missing from the dataset file (%s): %s' % (args.dataset_file, \
                ', '.join(map(str, missing_cols))))

    # Keep only samples that have the phenotype and every covariate.
    dataset.dropna(subset = required_cols, inplace = True)
    log('Filtered the dataset into %d samples with the phenotype and all covariates.' % len(dataset))

    phenotype_values = dataset[args.phenotype_col]
    covariates = dataset[covariate_cols]

    # Pick the score test according to the phenotype type (binary vs. continuous).
    if is_binary_series(phenotype_values):
        log('The phenotype (%s) was determined to be binary (%d cases).' % (args.phenotype_col, phenotype_values.sum()))
        score_test_class = BinaryTraitGeneScoreTest
    else:
        log('The phenotype (%s) was determined to be continuous.' % args.phenotype_col)
        score_test_class = ContinuousTraitGeneScoreTest

    # ---- Determine this task's share of the genes ----

    if total_tasks != 1:
        log('Running task %d/%d.' % (task_index, total_tasks))

    all_gene_effect_score_files = get_gene_effect_score_files(args.gene_effect_scores_dir)
    n_total_genes = len(all_gene_effect_score_files)
    gene_start, gene_end = get_chunk_slice(n_total_genes, total_tasks, task_index)
    log('Will run on genes %d-%d/%d.' % (gene_start, gene_end, n_total_genes))

    # ---- Run the per-gene tests, writing one result CSV per gene ----

    for i, (gene_index, gene_effect_score_file_name) in enumerate(all_gene_effect_score_files[gene_start:gene_end]):

        gene_effect_scores = pd.read_csv(os.path.join(args.gene_effect_scores_dir, gene_effect_score_file_name))
        gene_sample_ids = set(gene_effect_scores['sample_id'])
        # Samples present both in the filtered dataset and in this gene's score file, in dataset order
        # (the index dtype is preserved, even when no sample matches).
        gene_relevant_sample_ids = dataset.index[dataset.index.isin(gene_sample_ids)]
        log('Analyzing gene %d/%d (#%d; %d of %d samples are in the dataset)...' % (i + 1, gene_end - gene_start, gene_index, \
                len(gene_relevant_sample_ids), len(gene_sample_ids)))

        # Align phenotype, covariates and effect scores on the same sample order.
        gene_relevant_phenotype_values = phenotype_values.loc[gene_relevant_sample_ids]
        gene_relevant_covariates = covariates.loc[gene_relevant_sample_ids]
        gene_relevant_effect_scores = gene_effect_scores.set_index('sample_id').loc[gene_relevant_sample_ids]

        gene_results = GeneTest(gene_relevant_phenotype_values, gene_relevant_covariates, gene_relevant_effect_scores['dominant'], \
                gene_relevant_effect_scores['recessive'], score_test_class).run()
        gene_results.to_csv(os.path.join(args.per_gene_pwas_results_dir, '%d.csv' % gene_index), header = False)

    log('Done.')