#!/usr/bin/env python

from SomVarIUS_calling import write_mutations
import time
import argparse
import os, sys
import pysam
import cPickle
from assign_clones import clones
from query_muts import query_mutations


def query_given_muts(args):
    """Start a VCF at ``args.out`` and query the given mutations in the BAM.

    The VCF meta-information lines are written to ``args.out`` and the file
    is closed; ``query_mutations`` is then called with the same path.
    NOTE(review): presumably ``query_mutations`` appends the column header
    and records to that file -- confirm in query_muts.

    ``args`` is the namespace of the ``query_mutations`` sub-command and must
    provide: muts, bam, min_mapq, min_baseq, min_support, min_reads, out.
    """
    mutation_bed = args.muts
    bam_path = args.bam
    mapq_cutoff = args.min_mapq
    baseq_cutoff = args.min_baseq
    support_cutoff = args.min_support
    depth_cutoff = args.min_reads
    vcf_path = args.out

    vcf_meta_lines = (
        '##fileformat=VCFv4.0\n',
        '##source=SomVarIUS_V1.1\n',
        '##INFO=<ID=NS,Number=1,Type=Integer,Description="Number of Samples With Data">\n',
        '##INFO=<ID=DP,Number=1,Type=Integer,Description="Total Depth">\n',
        '##INFO=<ID=AF,Number=.,Type=Float,Description="Allele Frequency">\n',
        '##INFO=<ID=P,Number=.,Type=Float,Description="Combined Pvalue">\n',
        '##INFO=<ID=SP,Number=.,Type=Float,Description="Sequencing Error Pvalue">\n',
        '##INFO=<ID=GP,Number=.,Type=Float,Description="Germline Mutation Pvalue">\n',
        '##INFO=<ID=MQ,Number=.,Type=Float,Description="Mean Mapping Quality">\n',
        '##INFO=<ID=BQ,Number=.,Type=Float,Description="Mean Alternate Base Quality">\n',
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
        '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
        '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read Depth">\n',
    )

    # The header file is closed before query_mutations receives the path.
    with open(vcf_path, 'w') as vcf_handle:
        vcf_handle.writelines(vcf_meta_lines)

    print('Querying mutations')
    query_mutations(mutation_bed, bam_path, mapq_cutoff, baseq_cutoff,
                    depth_cutoff, support_cutoff, vcf_path)


def dbsnp_pickle(args):
    """Pickle a ``{(chrom, pos): (ref, alt)}`` lookup built from a dbSNP file.

    ``args.dbsnp`` is a tab-delimited text file; columns 1, 2, 4 and 5
    (1-based) are read as chromosome, integer position, reference allele and
    alternate allele.  Blank lines and lines starting with ``#`` are skipped.
    When a (chrom, pos) pair occurs more than once the last line wins.

    ``args.dbsnp_out`` is the path of the pickle file to write.

    Raises IndexError for a data line with fewer than five columns and
    ValueError when the position column is not an integer.
    """
    germ_pos = {}
    with open(args.dbsnp, 'r') as dbsnp_file:
        for line in dbsnp_file:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            fields = line.split('\t')
            germ_pos[(fields[0], int(fields[1]))] = (fields[3], fields[4])

    # Binary mode keeps the pickle byte-exact on every platform, and the
    # context manager guarantees it is flushed and closed before returning.
    with open(args.dbsnp_out, 'wb') as pickle_file:
        cPickle.dump(germ_pos, pickle_file)


def sort_files(args):
    """Sort the BAM file with pysam and write a sorted copy of the dbSNP file.

    ``args.bam`` is sorted into ``args.bam_out`` by ``pysam.sort`` with the
    ``-n`` option (samtools' sort-by-read-name flag).

    ``args.dbsnp`` is read as tab-delimited text; its rows are sorted and
    written to ``args.dbsnp_out``.  Rows are compared field by field as
    strings, so positions order lexicographically ("10" sorts before "2").
    """
    samfile_name = args.bam
    dbsnpfile_name = args.dbsnp
    output_samfile = args.bam_out
    output_dbsnp = args.dbsnp_out

    pysam.sort("-n", samfile_name, output_samfile)

    with open(dbsnpfile_name, 'r') as dbsnp_file:
        dbsnp_rows = [line.strip().split('\t') for line in dbsnp_file]
    # Sort once after reading everything rather than after every line.
    dbsnp_rows.sort()

    # The file is closed only after ALL rows are written; closing inside the
    # loop made the second write fail on a closed file.
    with open(output_dbsnp, 'w') as sorted_dbsnp:
        for row in dbsnp_rows:
            sorted_dbsnp.write('\t'.join(row) + '\n')


def get_mutations(args):
    """Call mutations for ``args.bam`` and write them to ``args.out`` as VCF.

    ``args`` is the namespace of the ``call_mutations`` sub-command.

    Steps, in order:
      1. Build the missing ``.bai`` / ``.fai`` indexes for the BAM, the
         reference FASTA and (when given) the RNA-seq BAM via pysam.
      2. Record the arguments in ``<out stem>_args.txt`` next to ``args.out``,
         where the stem is the output file name up to its first dot.
         Optional arguments are recorded only when they were supplied.
      3. Run ``write_mutations`` and write the lines it returns after the VCF
         header in ``args.out``.
      4. Print the ``time.clock()`` time spent in step 3.

    ``--rna_seq`` and ``--germ_pos`` default to ''; a non-empty value turns
    the RNA-seq filter / germline-position filter on.
    """
    chrom = args.chrom
    start = args.start
    if start is not None:
        start = int(start)
    end = args.end
    if end is not None:
        end = int(end)
    out = args.out
    rna_seq = args.rna_seq
    germ_pos = args.germ_pos
    copy_bed_name = args.copy_bed
    rna_filter = rna_seq != ''
    hap_filter = germ_pos != ''

    # Derive the side-file name from the file-name part only, so a dot in
    # the directory part (e.g. './out.vcf') does not truncate the path.
    out_dir, out_base = os.path.split(out)
    args_file_name = os.path.join(out_dir,
                                  out_base.split('.')[0] + '_args.txt')

    fmt = ('--bam {args.bam} --ref {args.ref} --out {args.out} '
           '--dbsnp_bed {args.dbsnp_bed} '
           '--min_reads {args.min_reads} --min_support {args.min_support} '
           '--min_af {args.min_af} --min_pvalue {args.min_pvalue} '
           '--min_fr {args.min_fr} '
           '--min_qual {args.min_qual} --min_se {args.min_se} '
           '--min_hetero {args.min_hetero} --min_mapq {args.min_mapq} '
           '--ref_filter {args.ref_filter} '
           '--min_baseq {args.min_baseq} --binom {args.binom}')
    # Optional arguments, each recorded once under its real flag name.
    if hap_filter:
        fmt += ' --germ_pos {args.germ_pos}'
    if copy_bed_name != '':
        fmt += ' --copy_bed {args.copy_bed}'
    if rna_filter:
        fmt += ' --rna_seq {args.rna_seq}'
    if chrom is not None:
        fmt += ' --chrom {args.chrom}'
    if start is not None:
        fmt += ' --start {args.start}'
    if end is not None:
        fmt += ' --end {args.end}'
    if args.dist != '':
        fmt += ' --dist {args.dist}'
    fmt += '\n'

    print('')

    if not os.path.isfile(args.bam + '.bai'):
        print('no alignment file index found')
        print('indexing alignment file')
        pysam.index(args.bam)
        print('done indexing alignment file')
        print('')

    if not os.path.isfile(args.ref + '.fai'):
        print('no fasta file index found')
        print('indexing fasta file')
        pysam.faidx(args.ref)
        print('done indexing fasta file')
        print('')

    if rna_filter and not os.path.isfile(args.rna_seq + '.bai'):
        print('no RNA-seq file index found')
        print('indexing RNA-seq file')
        pysam.index(args.rna_seq)
        print('done indexing RNA-seq file')
        print('')

    with open(args_file_name, mode='w') as args_file:
        args_file.write(fmt.format(args=args))
    print('wrote arguments to ' + args_file_name)
    print('')

    # The timer needs its own name: `start` is the region start that is
    # passed to write_mutations below and must not be overwritten.
    timer_start = time.clock()
    results = write_mutations(out, args.bam, args.ref, rna_seq, germ_pos,
                    args.min_reads, args.min_support, args.min_af,
                    args.min_pvalue, args.min_fr, rna_filter, hap_filter,
                    args.min_qual, args.min_se, args.min_hetero,
                    args.ref_filter, args.min_mapq, args.dbsnp_bed,
                    copy_bed_name, args.binom, chrom, start, end,
                    args.min_baseq, args.dist)

    vcf_meta_lines = (
        '##fileformat=VCFv4.0\n',
        '##source=SomVarIUS_V1.1\n',
        '##INFO=<ID=NS,Number=1,Type=Integer,Description="Number of Samples With Data">\n',
        '##INFO=<ID=DP,Number=1,Type=Integer,Description="Total Depth">\n',
        '##INFO=<ID=AF,Number=.,Type=Float,Description="Allele Frequency">\n',
        '##INFO=<ID=P,Number=.,Type=Float,Description="Combined Pvalue">\n',
        '##INFO=<ID=SP,Number=.,Type=Float,Description="Sequencing Error Pvalue">\n',
        '##INFO=<ID=GP,Number=.,Type=Float,Description="Germline Mutation Pvalue">\n',
        '##INFO=<ID=MQ,Number=.,Type=Float,Description="Mean Mapping Quality">\n',
        '##INFO=<ID=BQ,Number=.,Type=Float,Description="Mean Alternate Base Quality">\n',
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
        '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
        '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read Depth">\n',
    )
    # The sample column is named after the BAM path up to its '.bam' suffix.
    sample = args.bam.split('.bam')[0]

    with open(out, 'w') as out_file:
        out_file.writelines(vcf_meta_lines)
        out_file.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t'
                       + sample + '\n')
        out_file.writelines(results)
    print(time.clock() - timer_start)
    
if __name__ == '__main__':
    # Command-line entry point: each sub-command gets its own option set and
    # a handler stored as ``func``; the handler is called with the parsed
    # arguments.
    cli = argparse.ArgumentParser()
    commands = cli.add_subparsers()

    # Options that are declared identically for more than one sub-command.
    opt_min_reads = ('--min_reads', dict(
        help='minimum base coverage (default=10)', type=int, default=10))
    opt_min_support = ('--min_support', dict(
        help='minimum number of reads supported alternate allele (default=4)',
        type=int, default=4))
    opt_min_mapq = ('--min_mapq', dict(
        help='minimum mapping quality (default=55)', type=int, default=55))
    opt_min_baseq = ('--min_baseq', dict(
        help='minimum base quality (default=13)', type=int, default=13))

    # (sub-command, help text, handler, [(flag, add_argument keywords), ...])
    # Sub-commands and flags are listed in the order they are registered.
    command_table = [
        ('call_mutations', 'flag to call mutations', get_mutations, [
            ('--bam', dict(help='input bam file', required=True)),
            ('--ref', dict(help='reference fasta file', required=True)),
            ('--out', dict(help='output file', required=True)),
            ('--rna_seq', dict(help='RNA-seq bam file name', default='')),
            ('--germ_pos', dict(help='pickled hapmap file', default='')),
            ('--dbsnp_bed', dict(help='dbsnp bed file name', default='',
                                 required=True)),
            ('--copy_bed', dict(help='copy number bed file name', default='')),
            opt_min_reads,
            opt_min_support,
            ('--min_af', dict(
                help='minimum allele frequency (default=0.05)',
                type=float, default=0.05)),
            ('--min_pvalue', dict(
                help='minimum pvalue (default=0.0001)',
                type=float, default=0.0001)),
            ('--min_fr', dict(
                help='minimum reverse/forward read ratio (default=0.05)',
                type=float, default=0.05)),
            ('--min_qual', dict(
                help='minimum mean quality for alternate allele (default=25)',
                type=int, default=25)),
            ('--min_se', dict(
                help='minimum probability not sequencing error (default=0.999)',
                type=float, default=0.999)),
            ('--min_hetero', dict(
                help='minimum probability not germline (default=0.95)',
                type=float, default=0.95)),
            ('--ref_filter', dict(
                help='flag to filter by reference fasta (default=False)',
                default=False, action='store_true')),
            ('--binom', dict(
                help='flag to use binomial test instead of beta-binomial (default=False)',
                default=False, action='store_true')),
            opt_min_mapq,
            opt_min_baseq,
            ('--chrom', dict(help='Chromosome name to look at', default=None)),
            ('--start', dict(help='starting position', default=None)),
            ('--end', dict(help='ending position', default=None)),
            ('--dist', dict(
                help='write the beta binomial parameters to a file',
                default='')),
        ]),
        ('sort', 'flag to sort bam file and bed file by name', sort_files, [
            ('--bam', dict(help='input bam file', required=True)),
            ('--bam_out', dict(help='name of sorted bam file', required=True)),
            ('--dbsnp', dict(help='input dbsnp bed file', required=True)),
            ('--dbsnp_out', dict(help='name of sorted dbsnp file',
                                 required=True)),
        ]),
        ('pickle', 'flag to store the pickled germline positions from bed',
         dbsnp_pickle, [
            ('--dbsnp', dict(help='input dbsnp bed file', required=True)),
            ('--dbsnp_out', dict(help='name of pickled dbsnp file',
                                 required=True)),
        ]),
        ('clones', 'flag to classify as clone or sub-clone', clones, [
            ('--vcf', dict(help='vcf file', required=True)),
            ('--t', dict(help='tumor purity (default=1.0)', default=1.0,
                         type=float)),
            ('--gmm', dict(
                help='flag to classify by gaussian mixture model (default=False)',
                default=False, action='store_true')),
        ]),
        ('query_mutations', 'flag to query given mutations in the bam',
         query_given_muts, [
            ('--bam', dict(help='input bam file', required=True)),
            ('--out', dict(help='output file', required=True)),
            ('--muts', dict(help='mutation bed file', required=True)),
            opt_min_reads,
            opt_min_support,
            opt_min_mapq,
            opt_min_baseq,
        ]),
    ]

    for command_name, command_help, handler, option_specs in command_table:
        command_parser = commands.add_parser(command_name, help=command_help)
        for flag, keywords in option_specs:
            command_parser.add_argument(flag, **keywords)
        command_parser.set_defaults(func=handler)

    parsed = cli.parse_args()
    parsed.func(parsed)
