#!python
"""
Predicts the recombination rate for each genomic window in the POOL file
using the network trained by ReLERNN_TRAIN_POOL.py
"""

from ReLERNN.imports import *
from ReLERNN.helpers import *
from ReLERNN.sequenceBatchGenerator import *


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; these files are
    # expected to come from the project directory, so only point this script
    # at project directories you trust.
    with open(path, "rb") as fIN:
        return pickle.load(fIN)


def _read_pool(poolFILE):
    """Parse one split pool file.

    Each non-blank line is split on whitespace; column 0 is taken as the
    chromosome label, column 1 as an integer position and column 2 as a
    float frequency.

    Returns:
        (chrom, pos, fqs) where ``chrom`` is the label found on the last
        non-blank line (``None`` if the file holds no rows) and ``pos`` /
        ``fqs`` are numpy arrays of the parsed columns.
    """
    chrom = None
    pos, fqs = [], []
    with open(poolFILE, "r") as fIN:
        for line in fIN:
            ar = line.split()
            if not ar:
                # Tolerate blank lines (e.g. a trailing newline at end of file)
                continue
            pos.append(int(ar[1]))
            fqs.append(float(ar[2]))
            chrom = ar[0]
    return chrom, np.array(pos), np.array(fqs)


def main():
    """Predict on every split pool file listed in windowSizes.txt and merge the results.

    Reads the command-line arguments, loads the per-window settings from
    ``<projectDir>/networks/windowSizes.txt``, runs ``load_and_predictVCF``
    once per listed entry (writing ``<label>.CHPREDICT.txt`` into the project
    directory), then concatenates every ``*.CHPREDICT.txt`` found in the
    project directory into a single ``.PREDICT.txt`` file and deletes the
    per-entry files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p','--pool',dest='pool',help='Filtered and QC-checked pool file')
    parser.add_argument('-d','--projectDir',dest='outDir',help='Directory for all project output. NOTE: the same projectDir must be used for all functions of ReLERNN',default=None)
    parser.add_argument('--minSites',dest='minS',help='Minimum number of SNPs in a genomic window required to return a prediction', type=int, default = 50)
    parser.add_argument('--gpuID',dest='gpuID',help='Identifier specifying which GPU to use', type=int, default = 0)
    args = parser.parse_args()

    # Fail with a usage message instead of a TypeError from os.path.basename(None)
    if not args.pool:
        parser.error("a pool file must be supplied with -p/--pool")

    ## Set up the directory structure to store the simulations data.
    if not args.outDir:
        print("Warning: No project directory found, using current working directory.")
        projectDir = os.getcwd()
    else:
        projectDir = args.outDir
    trainDir = os.path.join(projectDir,"train")
    valiDir = os.path.join(projectDir,"vali")
    testDir = os.path.join(projectDir,"test")
    networkDir = os.path.join(projectDir,"networks")
    poolDir = os.path.join(projectDir,"splitPOOLs")
    modelSave = os.path.join(networkDir,"model.json")
    weightsSave = os.path.join(networkDir,"weights.h5")

    # Computed up front so the output name exists even when no windows are listed
    bn = os.path.basename(args.pool)

    ## Read in the window sizes
    maxSimS = 0
    wins = []
    winFILE = os.path.join(networkDir,"windowSizes.txt")
    with open(winFILE, "r") as fIN:
        for line in fIN:
            ar = line.split()
            if not ar:
                continue
            wins.append([ar[0],int(ar[1]),int(ar[2]),int(ar[3]),int(ar[4]),int(ar[5])])
            maxSimS = max(maxSimS, int(ar[5]))

    ## Identify padding required (identical for every window, so done once)
    maxSegSites = 0
    for ds in [trainDir,valiDir,testDir]:
        dsInfo = _load_pickle(os.path.join(ds,"info.p"))
        maxSegSites = max(maxSegSites, max(dsInfo["segSites"]))
    maxSegSites = max(maxSegSites, maxSimS)

    ## Identify the target normalization used to train
    batchPars = _load_pickle(os.path.join(networkDir,"batchPars.p"))
    normType = batchPars["targetNormalization"]

    ## Loop through chromosomes and predict
    for win in wins:
        label, winSize, batchSize = win[0], win[2], win[4]
        poolFILE = os.path.join(poolDir,bn.replace(".pool","_%s.pool" %(label)))
        print("""Importing POOL: "%s"...""" %(poolFILE))
        chrom, pos, fqs = _read_pool(poolFILE)
        if chrom is None:
            # Nothing to predict on; avoids reusing a label left over from
            # a previously parsed file.
            print("""Warning: POOL file "%s" contains no sites, skipping.""" %(poolFILE))
            continue

        # Reloaded per window so every generator receives its own copy
        DsInfoDir = _load_pickle(os.path.join(trainDir,"info.p"))

        ## Set network parameters
        bds_pred_params = {
            'INFO':DsInfoDir,
            'CHROM':chrom,
            'WIN':winSize,
            'IDs':get_index(pos,winSize),
            'GT':fqs,
            'POS':pos,
            'batchSize': batchSize,
            'maxLen': maxSegSites,
            'frameWidth': 5,
            'sortInds':False,
            'center':False,
            'ancVal':-1,
            'padVal':0,
            'derVal':1,
            'realLinePos':True,
            'posPadVal':0,
            'normType':normType
                  }

        ### Define sequence batch generator
        pred_sequence = POOLBatchGenerator(**bds_pred_params)

        ## Load trained model and make predictions on pool data
        pred_resultFile = os.path.join(projectDir,label+".CHPREDICT.txt")
        load_and_predictVCF(VCFGenerator=pred_sequence,
                resultsFile=pred_resultFile,
                network=[modelSave,weightsSave],
                minS=args.minS,
                gpuID=args.gpuID)

    ## Combine chromosome predictions in whole genome prediction file and rm chromosome files
    genPredFILE = os.path.join(projectDir,bn.replace(".pool",".PREDICT.txt"))
    files = sorted(glob.glob(os.path.join(projectDir,"*.CHPREDICT.txt")))
    with open(genPredFILE, "w") as fOUT:
        for ct, f in enumerate(files):
            with open(f, "r") as fIN:
                if ct > 0:
                    # Keep the first line (header) from the first file only
                    fIN.readline()
                fOUT.writelines(fIN)
            # os.remove avoids a shell call, so paths with spaces or shell
            # metacharacters are handled correctly
            os.remove(f)

    print("\n***ReLERNN_PREDICT_POOL.py FINISHED!***\n")


# Run the prediction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
