#!python

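'''
Gather per-sample coverage tables into a single cohort-wide table.

input:  a collection of per-sample coverage TSVs (indexed by 'ref' and 'pos')
output: a single large 2D table (TSV) holding the columns of all samples,
        and optionally a TSV of per-position cohort-wide statistics
'''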

import os
import sys

import argparse
from multiprocessing import Pool
import pandas as pd

# Module metadata: authorship, licence, maintainer contact and development status.
__author__ = "Ivan Blagoev Topolsky"
__copyright__ = "Copyright 2020"
__credits__ = "Ivan Blagoev Topolsky"
__license__ = "GPL2+"
__maintainer__ = "Ivan Blagoev Topolsky"
__email__ = "v-pipe@bsse.ethz.ch"
__status__ = "Development"



def loader(tsvname):
    """Read one per-sample coverage table.

    The file is parsed as tab-separated text, letting pandas infer the
    compression, and its 'ref' and 'pos' columns become the two-level
    row index of the returned data frame.
    """
    read_options = {
        'sep': "\t",
        'compression': 'infer',
        'index_col': ['ref', 'pos'],
        'low_memory': True,
        'memory_map': True,
    }
    sample_table = pd.read_csv(tsvname, **read_options)
    return sample_table

def parse_args():
    """Set up the command-line parser and parse the arguments.

    Options:
        -o/--output   unified coverage table output file (stored as
                      ``coverage``); optional, defaults to 'coverage.tsv.gz'.
        -s/--stats    optional file to write per-position cohort-wide stats to.
        -t/--threads  number of threads (default 4).
        INPUT         one or more per-sample coverage table input files.

    Arguments can also be read from a file by passing ``@listfile``.

    Returns:
        argparse.Namespace with the attributes ``coverage``, ``stats``,
        ``threads`` and ``INPUT``.
    """
    parser = argparse.ArgumentParser(
        fromfile_prefix_chars='@',  # support passing the TSV list in a file instead of parameters
        description="Script to gather multiple per sample coverage information into a single unified file",
        epilog="input and output TSVs support compression.\n@listfile can be used to pass a long list of TSV in a file instead of command line")
    # The output option carries a default, so it must not be flagged as
    # required: a required option never falls back to its default value.
    parser.add_argument('-o', '--output', metavar='TSV', required=False,
                        default='coverage.tsv.gz',
                        type=str, dest='coverage', help='unified coverage table output file')
    parser.add_argument('-s', '--stats', metavar='TSV', required=False,
                        type=str, dest='stats', help="file to write per-position cohort-wide stats to")
    parser.add_argument('-t', '--threads', metavar='NCPUS', required=False,
                        default=4,
                        type=int, dest='threads', help="number of threads")
    parser.add_argument("INPUT", nargs='+', metavar='TSV', help="per sample coverage table input file(s)")

    return parser.parse_args()


def main():
    """Load every input TSV in parallel and merge them into one table.

    A multiprocessing pool of ``args.threads`` worker processes runs
    ``loader`` on each input file. The resulting frames are concatenated
    column-wise with an outer join on their index (asking pandas not to
    copy), and the merged table is written to the output file. When a
    stats file was requested, per-row ``describe`` statistics are written
    to it as well. Progress messages go to stderr.
    """
    args = parse_args()

    print(f'Threads: {args.threads}', file=sys.stderr)
    print(f'Gathering {len(args.INPUT)} files...', file=sys.stderr)
    with Pool(processes=args.threads) as workers:
        frames = workers.map(loader, args.INPUT)

    # place the per-sample frames side by side, aligned on their index
    print('merging...', file=sys.stderr)
    merged = pd.concat(
        frames,
        axis=1,
        join='outer',
        ignore_index=False,
        copy=False,
    )
    print('done.', file=sys.stderr)

    # unified table; compression is inferred from the output file name
    merged.to_csv(args.coverage, sep="\t", compression={'method': 'infer'})

    # the cohort-wide per-position stats are optional
    if not args.stats:
        return
    per_position = merged.apply(pd.DataFrame.describe, axis=1)
    per_position.to_csv(args.stats, sep="\t", compression={'method': 'infer'})


# Run the gathering only when executed as a script, not on import; under the
# 'spawn' start method multiprocessing workers re-import this module, so the
# guard also keeps them from re-running main().
if __name__ == '__main__':
    main()
