#!python
"""
Predicts the recombination rate for each genomic window in the VCF file
using a GRU network trained in ReLERNN_TRAIN.py
"""

from ReLERNN.imports import *
from ReLERNN.helpers import *
from ReLERNN.sequenceBatchGenerator import *


def _read_window_sizes(winFILE):
    """Parse the per-chromosome window table written by earlier ReLERNN steps.

    Each non-blank line is whitespace-split; field 0 is kept as a string
    (used below as the chromosome name) and fields 1-5 are cast to int.
    Field 2 is used as the window size, field 4 as the batch size, and
    field 5 contributes to the padding length (meaning of fields 1 and 3 is
    not visible here -- TODO confirm against the writer).

    Returns:
        (wins, maxSimS): the list of parsed records and the maximum of
        field 5 over all records (0 when the file has no records).
    """
    wins = []
    maxSimS = 0
    with open(winFILE, "r") as fIN:
        for line in fIN:
            ar = line.split()
            if not ar:
                # Tolerate blank / trailing empty lines instead of IndexError.
                continue
            wins.append([ar[0], int(ar[1]), int(ar[2]), int(ar[3]), int(ar[4]), int(ar[5])])
            maxSimS = max(maxSimS, int(ar[5]))
    return wins, maxSimS


def _load_info(dsDir):
    """Load and return the pickled dataset-info object stored at ``<dsDir>/info.p``."""
    # NOTE(review): pickle.load can execute arbitrary code on crafted input;
    # info.p is expected to be produced by earlier ReLERNN steps in projectDir.
    with open(os.path.join(dsDir, "info.p"), "rb") as fIN:
        return pickle.load(fIN)


def _merge_prediction_files(resultFiles, genPredFILE):
    """Concatenate per-chromosome prediction files into ``genPredFILE``.

    The first file is copied verbatim; for every later file its first line
    (presumably the column header written by load_and_predictVCF) is skipped.
    Each input file is deleted after it has been copied and closed.
    """
    with open(genPredFILE, "w") as fOUT:
        for idx, f in enumerate(resultFiles):
            with open(f, "r") as fIN:
                if idx > 0:
                    fIN.readline()  # drop the repeated first (header) line
                fOUT.writelines(fIN)
            # os.remove avoids building a shell command from a file path.
            os.remove(f)


def main():
    """Predict recombination rates for every chromosome listed in the project.

    Command-line entry point. Reads ``<projectDir>/networks/windowSizes.txt``.
    For each listed chromosome, it opens the matching split HDF5 file in
    ``<projectDir>/splitVCFs`` and builds a VCFBatchGenerator. It then runs
    the saved network (``model.json`` / ``weights.h5``) through
    load_and_predictVCF, writing ``<chrom>.CHPREDICT.txt``. Finally, it
    merges those per-chromosome files into ``<vcf basename>.PREDICT.txt``
    in projectDir and deletes them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v','--vcf',dest='vcf',help='Filtered and QC-checked VCF file. Important: Every row must correspond to a biallelic SNP with no missing data!')
    parser.add_argument('-d','--projectDir',dest='outDir',help='Directory for all project output. NOTE: the same projectDir must be used for all functions of ReLERNN',default=None)
    parser.add_argument('--phased',help='VCF file is phased',default=False, action='store_true')
    parser.add_argument('--unphased',dest='phased',help='VCF file is unphased',action='store_false')
    parser.add_argument('--minSites',dest='minS',help='Minimum number of SNPs in a genomic window required to return a prediction', type=int, default = 50)
    parser.add_argument('--gpuID',dest='gpuID',help='Identifier specifying which GPU to use', type=int, default = 0)
    args = parser.parse_args()
    if not args.vcf:
        # Fail with a clear usage message instead of a TypeError in basename().
        parser.error("argument -v/--vcf is required")

    ## Resolve the project directory layout shared by all ReLERNN steps.
    if not args.outDir:
        print("Warning: No project directory found, using current working directory.")
        projectDir = os.getcwd()
    else:
        projectDir = args.outDir
    trainDir = os.path.join(projectDir,"train")
    valiDir = os.path.join(projectDir,"vali")
    testDir = os.path.join(projectDir,"test")
    networkDir = os.path.join(projectDir,"networks")
    vcfDir = os.path.join(projectDir,"splitVCFs")
    modelSave = os.path.join(networkDir,"model.json")
    weightsSave = os.path.join(networkDir,"weights.h5")

    ## Read in the window sizes
    wins, maxSimS = _read_window_sizes(os.path.join(networkDir,"windowSizes.txt"))

    ## Padding length and training info do not depend on the chromosome,
    ## so compute them once instead of once per loop iteration.
    trainInfo = _load_info(trainDir)
    maxSegSites = maxSimS
    for info in (trainInfo, _load_info(valiDir), _load_info(testDir)):
        maxSegSites = max(maxSegSites, max(info["segSites"]))

    # Defined before the loop so the merge step works even with no windows.
    bn = os.path.basename(args.vcf)

    ## Loop through chromosomes and predict
    pred_resultFiles = []
    for win in wins:
        chrom, winSize, batchSize = win[0], win[2], win[4]
        h5FILE = os.path.join(vcfDir, bn.replace(".vcf", "_%s.hdf5" %(chrom)))
        print("""Importing HDF5: "%s"...""" %(h5FILE))
        # Keep the HDF5 file open until prediction is done: the chunked arrays
        # wrap its datasets (presumably read lazily), and closing it
        # afterwards releases the handle the original code leaked.
        with h5py.File(h5FILE, mode="r") as callset:
            var = allel.VariantChunkedTable(callset["variants"], names=["CHROM","POS"], index="POS")
            chroms = var["CHROM"]
            pos = var["POS"]
            genos = allel.GenotypeChunkedArray(callset["calldata"]["GT"])

            ## Set network parameters
            bds_pred_params = {
                'INFO':trainInfo,
                'CHROM':chroms[0],
                'WIN':winSize,
                'IDs':get_index(pos,winSize),
                'GT':genos,
                'POS':pos,
                'batchSize': batchSize,
                'maxLen': maxSegSites,
                'frameWidth': 5,
                'sortInds':False,
                'center':False,
                'ancVal':-1,
                'padVal':0,
                'derVal':1,
                'realLinePos':True,
                'posPadVal':0,
                'phase':args.phased
                      }

            ### Define sequence batch generator
            pred_sequence = VCFBatchGenerator(**bds_pred_params)

            ## Load trained model and make predictions on VCF data
            pred_resultFile = os.path.join(projectDir, chrom + ".CHPREDICT.txt")
            pred_resultFiles.append(pred_resultFile)
            load_and_predictVCF(VCFGenerator=pred_sequence,
                    resultsFile=pred_resultFile,
                    network=[modelSave,weightsSave],
                    minS=args.minS,
                    gpuID=args.gpuID)

    ## Combine chromosome predictions into the whole-genome file and remove the chromosome files
    genPredFILE = os.path.join(projectDir, bn.replace(".vcf", ".PREDICT.txt"))
    _merge_prediction_files(pred_resultFiles, genPredFILE)

    print("\n***ReLERNN_PREDICT.py FINISHED!***\n")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
