#!python

'''
input:  alignment file. An index file must exist in the same directory.
output: tab-separated files.
        alignment positions and references on the rows and bases as columns.
        Loci are reported using 0-based indexing by default
        (use -f/--first to switch to 1-based positions).
'''

import os
import sys

import pysam
import argparse
import configparser
import yaml
import numpy as np
import pandas as pd


from smallgenomeutilities.__pileup__ import get_cnt_matrix

# Module metadata (authorship, licensing and maintenance contact).
__author__ = "Ivan Blagoev Topolsky"
__copyright__ = "Copyright 2020"
__credits__ = "Ivan Blagoev Topolsky"
__license__ = "GPL2+"
__maintainer__ = "Ivan Blagoev Topolsky"
__email__ = "v-pipe@bsse.ethz.ch"
__status__ = "Development"



def parse_args():
    """Set up the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with the attributes
            ``basecnt``  -- path of the base-count TSV to write,
            ``coverage`` -- path of the coverage TSV to write,
            ``name``     -- label used for the sample/coverage column,
            ``alpha``    -- string of the characters to count,
            ``first``    -- number given to the first position (0 or 1),
            ``stats``    -- optional path of the stats file (``None`` if unset),
            ``FILE``     -- one-element list holding the alignment file path.
    """
    parser = argparse.ArgumentParser(
        description="Script to extract base counts and coverage information from a single alignment file",
        epilog="output TSVs support compression")
    # The two output options declare default file names; they must therefore be
    # optional, otherwise argparse would never fall back on those defaults.
    parser.add_argument('-b', '--basecnt', metavar='TSV', required=False,
                        default='basecnt.tsv.gz',
                        type=str, dest='basecnt', help='bases count table output file')
    parser.add_argument('-c', '--coverage', metavar='TSV', required=False,
                        default='coverage.tsv.gz',
                        type=str, dest='coverage', help='coverage table output file')
    parser.add_argument('-N', '--name', metavar='patient-sample', required=False,
                        default='coverage',
                        type=str, dest='name', help="Patient/sample identifiers to use in coverage column title instead of 'coverage'")
    parser.add_argument('-A', '--alphabet', metavar='alphabet', required=False,
                        default='ACGT-',
                        type=str, dest='alpha', help='alphabet to use')
    parser.add_argument('-f', '--first', metavar='ARRAY_BASE', required=False,
                        default=0, choices=[0,1],
                        type=int, dest='first', help='select whether the first position is named "0" (standard for python tools such as pysam, older versions of smallgenomeutilities, and the BED format) or "1" (standard scientific notation used in most tools, and most text formats such as VCF and GFF)')
    parser.add_argument('-s', '--stats', metavar='YAML/JSON/INI', required=False,
                        type=str, dest='stats', help="file to write stats to")
    parser.add_argument("FILE", nargs=1, metavar='BAM/CRAM', help="alignment file")

    return parser.parse_args()


def _write_stats(desc, path):
    """Write the summary table *desc* to *path*, format chosen by extension.

    ``.yaml`` produces YAML, ``.ini`` produces an INI file (one section per
    column of *desc*), any other extension produces JSON.
    """
    ext = os.path.splitext(path)[1]
    if ext == '.yaml':
        with open(path, 'w') as yf:
            print(yaml.dump(desc.to_dict()), file=yf)
    elif ext == '.ini':
        ini = configparser.ConfigParser()
        ini.read_dict(desc.to_dict())
        with open(path, 'w') as inif:
            ini.write(inif)
    else:
        desc.to_json(path)


def main():
    """Extract base counts and coverage from one alignment file.

    For every reference listed in the alignment header, a count matrix is
    obtained from ``get_cnt_matrix`` and turned into two tables indexed by
    (reference, position): the per-character counts and their row sums
    (coverage). Both are written as tab-separated files; optionally a stats
    file with the coverage summary, read count, mean insert size, mean read
    length and GC fraction is written too.
    """
    args = parse_args()

    bamfile = args.FILE[0]

    # some useful stats
    reads = 0   # read count
    instot = 0  # total template_length    ( = insert_size * reads count )
    rltot = 0   # total read bases         ( = read_len    * reads count )

    # column layout shared by every per-reference base-count table
    cols = pd.MultiIndex.from_product([[args.name], list(args.alpha)], names=['sample', 'nt'])

    # Per-reference tables are gathered here and concatenated once at the end.
    # This keeps every reference in the output (including leading references
    # without any read) and avoids re-copying the growing table on each loop.
    basecnt_parts = []
    coverage_parts = []

    mode = 'rc' if os.path.splitext(bamfile)[1] == '.cram' else 'rb'
    with pysam.AlignmentFile(bamfile, mode) as alnfile:
        for reference_name in alnfile.references:
            region_len = alnfile.get_reference_length(reference_name)
            # progress report, overwriting the same terminal line
            print(f"\r{reference_name} [{region_len}]... \033[0K", end='', file=sys.stderr)

            # NOTE(review): nt_counts is presumably a (region_len x len(alphabet))
            # array, as it is used with that index/column layout below -- confirm
            # against get_cnt_matrix.
            nt_counts, nr, ins, rl = get_cnt_matrix(alnfile, reference_name, alpha=args.alpha)

            # positions are numbered starting from args.first (0 or 1)
            index = pd.MultiIndex.from_product(
                [[reference_name], np.arange(args.first, region_len + args.first)],
                names=['ref', 'pos'])
            basecnt_parts.append(pd.DataFrame(data=nt_counts, index=index, columns=cols))

            # NOTE the following will include all nt from the alphabet (including e.g.: `-`) and skip any non listed (e.g.: `N`)
            np_coverage = np.sum(nt_counts, axis=1)
            coverage_parts.append(pd.DataFrame(data=np_coverage, index=index, columns=[args.name]))

            # gather useful stats
            reads += nr
            instot += ins
            rltot += rl
        print("\rdone.\033[0K", file=sys.stderr)

    # data frames holding the TSV data (empty tables if the file lists no reference)
    basecnt = pd.concat(basecnt_parts) if basecnt_parts else pd.DataFrame(columns=cols)
    coverage = pd.concat(coverage_parts) if coverage_parts else pd.DataFrame(columns=[args.name])

    # save the TSV files
    coverage.to_csv(args.coverage, sep="\t", compression={'method': 'infer'})
    basecnt.to_csv(args.basecnt, sep="\t", compression={'method': 'infer'})

    if not args.stats:
        return

    if reads:
        # base sums, for GC content; characters absent from a custom alphabet count as 0
        basesum = basecnt.sum(axis=0)[args.name]
        GC = basesum.get('G', 0) + basesum.get('C', 0)
        AT = basesum.get('A', 0) + basesum.get('T', 0)
        bstot = GC + AT
        GC_frac = (GC / bstot) if bstot else 0.0

        # append the useful stats to the summary description
        desc = pd.concat([
            coverage.describe(include=[np.number]),
            pd.DataFrame({args.name: {
                'reads': reads,
                'insert_mean': (instot / reads),
                'readlen_mean': (rltot / reads),
                'GC': GC_frac,
            }}),
        ])
    else:
        # only zeros if no reads at all; same keys as the regular summary above
        zero_keys = ('count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max',
                     'reads', 'insert_mean', 'readlen_mean', 'GC')
        desc = pd.DataFrame({args.name: {k: 0.0 for k in zero_keys}})

    _write_stats(desc, args.stats)

# Run the tool only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
