#!python
"""Trains a network on data simulated by ReLERNN_SIMULATE_POOL.py"""

from ReLERNN.imports import *
from ReLERNN.helpers import *
from ReLERNN.sequenceBatchGenerator import *
from ReLERNN.networks import *


def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # files inside the project directory that the pipeline itself produced.
    with open(path, "rb") as fIN:
        return pickle.load(fIN)


def main():
    """Command-line entry point: train the pooled GRU network.

    Steps, in order:
      1. Parse the command line. ``--readDepth`` must be a positive integer;
         otherwise an error is printed and the process exits with status 1.
      2. Resolve the project directory (falls back to the current working
         directory) and its ``train``/``vali``/``test``/``networks``
         sub-directories.
      3. Work out the padding length from ``networks/windowSizes.txt`` and
         the ``segSites`` entry of each data set's ``info.p``.
      4. Pickle the training batch parameters to ``networks/batchPars.p``.
      5. Build the train/validation/test ``SequenceBatchGenerator`` objects,
         hand them to ``runModels`` with ``GRU_POOLED``, then call
         ``plotResults`` on the results file.

    Raises:
        SystemExit: if ``--readDepth`` is missing, zero or negative.
        OSError: if ``windowSizes.txt`` or any ``info.p`` cannot be opened.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d','--projectDir',dest='outDir',help='Directory for all project output. NOTE: the same projectDir must be used for all functions of ReLERNN',default=None)
    parser.add_argument('--readDepth',dest='seqD',help='Mean read depth of the pool', type=int, default=0)
    parser.add_argument('--maf',dest='maf',help='discard simulated sites with allele frequencies < maf', type=float, default=0.05)
    parser.add_argument('--nEpochs',dest='nEpochs',help='Maximum number of epochs to train (EarlyStopping is implemented for validation accuracy)', type=int, default=1000)
    parser.add_argument('--nValSteps',dest='nValSteps',help='Number of validation steps', type=int, default=20)
    parser.add_argument('-t','--nCPU',dest='nCPU',help='Number of CPUs to use',type=int,default=1)
    parser.add_argument('--gpuID',dest='gpuID',help='Identifier specifying which GPU to use', type=int, default=0)
    args = parser.parse_args()


    print("Warning: training data to be treated as if generated by pool-seq")
    # The default of 0 means "not supplied"; a negative depth is equally
    # unusable, so reject both up front instead of failing later on.
    if args.seqD <= 0:
        print("Error: assumed sequencing depth must be provided.")
        sys.exit(1)


    ## Set up the directory structure to store the simulations data.
    if not args.outDir:
        print("Warning: No project directory found, using current working directory.")
        projectDir = os.getcwd()
    else:
        projectDir = args.outDir
    trainDir = os.path.join(projectDir,"train")
    valiDir = os.path.join(projectDir,"vali")
    testDir = os.path.join(projectDir,"test")
    networkDir = os.path.join(projectDir,"networks")


    ## Define output files
    test_resultFile = os.path.join(networkDir,"testResults.p")
    test_resultFig = os.path.join(networkDir,"testResults.pdf")
    modelSave = os.path.join(networkDir,"model.json")
    weightsSave = os.path.join(networkDir,"weights.h5")


    ## Identify padding required
    # The sixth whitespace-separated field of every windowSizes.txt line is
    # read as a site count (presumably the simulated maximum per window --
    # confirm against the script that writes this file).
    maxSimS = 0
    winFILE = os.path.join(networkDir,"windowSizes.txt")
    with open(winFILE, "r") as fIN:
        for line in fIN:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            maxSimS = max(maxSimS, int(fields[5]))

    # Load each data set's info.p exactly once; the test entry is reused
    # further down for the test batch size.
    dsInfo = {}
    maxSegSites = 0
    for ds in [trainDir,valiDir,testDir]:
        dsInfo[ds] = _load_pickle(os.path.join(ds,"info.p"))
        maxSegSites = max(maxSegSites, max(dsInfo[ds]["segSites"]))
    maxSegSites = max(maxSegSites, maxSimS)


    ## Set network parameters
    bds_train_params = {
        'treesDirectory':trainDir,
        'targetNormalization':"zscore",
        'batchSize': 64,
        'maxLen': maxSegSites,
        'frameWidth': 5,
        'shuffleInds':True,
        'sortInds':False,
        'center':False,
        'ancVal':-1,
        'padVal':0,
        'derVal':1,
        'realLinePos':True,
        'posPadVal':0,
        'seqD':args.seqD,
        'maf':args.maf
              }


    ## Dump batch pars for bootstrap
    batchParsFILE = os.path.join(networkDir,"batchPars.p")
    with open(batchParsFILE, "wb") as fOUT:
        pickle.dump(bds_train_params,fOUT)


    # Validation and test generators start from the training settings and
    # override only what differs.
    bds_vali_params = copy.deepcopy(bds_train_params)
    bds_vali_params['treesDirectory'] = valiDir
    bds_vali_params['batchSize'] = 64

    bds_test_params = copy.deepcopy(bds_train_params)
    bds_test_params['treesDirectory'] = testDir
    # Batch size for the test generator is the test set's "numReps" value.
    bds_test_params['batchSize'] = dsInfo[testDir]["numReps"]
    bds_test_params['shuffleExamples'] = False


    ## Define sequence batch generator
    train_sequence = SequenceBatchGenerator(**bds_train_params)
    vali_sequence = SequenceBatchGenerator(**bds_vali_params)
    test_sequence = SequenceBatchGenerator(**bds_test_params)


    ## Train network
    runModels(ModelFuncPointer=GRU_POOLED,
            ModelName="GRU_POOLED",
            TrainDir=trainDir,
            TrainGenerator=train_sequence,
            ValidationGenerator=vali_sequence,
            TestGenerator=test_sequence,
            resultsFile=test_resultFile,
            network=[modelSave,weightsSave],
            numEpochs=args.nEpochs,
            validationSteps=args.nValSteps,
            nCPU=args.nCPU,
            gpuID=args.gpuID)


    ## Plot results of predictions on test set
    plotResults(resultsFile=test_resultFile,saveas=test_resultFig)


    print("\n***ReLERNN_TRAIN_POOL.py FINISHED!***\n")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
