#!python
"""
Performs a parametric bootstrap to assess any potential bias in recombination rate predictions.
Corrects for this bias and adds 95% confidence intervals to the predictions.
"""


from ReLERNN.imports import *
from ReLERNN.helpers import *
from ReLERNN.simulator import *
from ReLERNN.sequenceBatchGenerator import *


def ParametricBootStrap(simParameters,
                        batchParameters,
                        trainDir,
                        network=None,
                        slices=1000,
                        repsPerSlice=1000,
                        gpuID=0,
                        tempDir="./Temp",
                        out="./ParametricBootstrap.p",
                        nCPU=1):
    '''
    Assess the confidence of a trained network over a range of rho
    using a parametric bootstrap.

    The prior range [priorLowsRho, priorHighsRho) taken from
    ``simParameters`` is divided into ``slices`` evenly spaced values of
    rho (the upper bound itself is not simulated).  For every value a test
    set of ``repsPerSlice`` replicates is simulated with rho fixed at that
    value, the trained model predicts on it, and summary statistics of the
    un-normalized predictions are recorded.

    SIDE NOTE: ``tempDir`` is created if it does not exist and is re-used
    for the simulated test set of every slice.  This function does NOT
    remove ``tempDir``; cleaning it up is left to the caller.

    Parameters:
        simParameters:   dict of keyword arguments for ``Simulator``.  Must
                         contain 'priorLowsRho' and 'priorHighsRho'; a "bn"
                         entry, if present, is dropped before the
                         simulator is built.  The dict itself is not
                         modified (a deep copy is used per slice).
        batchParameters: dict of keyword arguments for
                         ``SequenceBatchGenerator``; 'treesDirectory',
                         'batchSize' and 'shuffleExamples' are overridden
                         on a deep copy.
        trainDir:        directory handed to ``getMeanSDMax`` to obtain the
                         mean and sd used to un-normalize predictions.
        network:         pair ``[model_json_path, weights_path]`` of the
                         trained network.
        slices:          number of rho values to simulate at.
        repsPerSlice:    number of simulated replicates per rho value.
        gpuID:           value written to ``CUDA_VISIBLE_DEVICES``.
        tempDir:         working directory for the simulated trees.
        out:             path of the pickle file the results are written to.
        nCPU:            number of processes passed to the simulator.

    Returns:
        ``(rho, IQR)`` where ``rho`` is the list of simulated rho values
        and ``IQR`` is the dictionary that is also pickled to ``out``.  Its
        "rho" key holds the same list; "Min", "CI95LO", "Q1", "Q2", "Q3",
        "CI95HI" and "Max" hold, per slice, the minimum, the 2.5/25/50/75/
        97.5 percentiles and the maximum of the predictions, each passed
        through ``relu`` (presumably clamping negatives to zero -- confirm
        against ReLERNN.helpers).

    Raises:
        ValueError: if ``network`` is None.
    '''

    # Fail fast: without a model nothing below can succeed, and the first
    # simulation slice is expensive.  Checked before any side effects.
    if network is None:
        raise ValueError("Error: no pretrained network found!")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpuID)

    # load json and create model
    with open(network[0], "r") as jsonFILE:
        loadedModel = jsonFILE.read()
    model = model_from_json(loadedModel)
    model.load_weights(network[1])

    if not os.path.exists(tempDir):
        os.makedirs(tempDir)

    priorLowsRho = simParameters['priorLowsRho']
    priorHighsRho = simParameters['priorHighsRho']

    # Evenly spaced rho values starting at the lower prior bound.
    rhoDiff = (priorHighsRho - priorLowsRho)/slices
    rho = [(priorLowsRho+(rhoDiff*i)) for i in range(slices)]
    IQR = {"rho":rho,"Min":[],"CI95LO":[],"Q1":[],"Q2":[],"Q3":[],"CI95HI":[],"Max":[]}

    # Only mean and sd are needed here; the third returned value is unused.
    mean,sd,_ = getMeanSDMax(trainDir)

    for idx,r in enumerate(rho):
        print("Simulating slice ",idx," out of ",slices)

        # Pin both prior bounds to r so every replicate uses this rho.
        params = copy.deepcopy(simParameters)
        params["priorLowsRho"] = r
        params["priorHighsRho"] = r
        params.pop("bn", None)
        simulator = Simulator(**params)

        simulator.simulateAndProduceTrees(numReps=repsPerSlice,
                                          direc=tempDir,
                                          simulator="msprime",
                                          nProc=nCPU)

        # One unshuffled batch that spans the whole simulated test set.
        batch_params = copy.deepcopy(batchParameters)
        batch_params['treesDirectory'] = tempDir
        batch_params['batchSize'] = repsPerSlice
        batch_params['shuffleExamples'] = False
        batchGenerator = SequenceBatchGenerator(**batch_params)

        x,_ = batchGenerator[0]
        predictions = unNormalize(mean,sd,model.predict(x))
        predictions = [p[0] for p in predictions]

        minP,maxP = min(predictions),max(predictions)
        quartiles = np.percentile(predictions,[2.5,25,50,75,97.5])

        IQR["Min"].append(relu(minP))
        IQR["Max"].append(relu(maxP))
        IQR["CI95LO"].append(relu(quartiles[0]))
        IQR["Q1"].append(relu(quartiles[1]))
        IQR["Q2"].append(relu(quartiles[2]))
        IQR["Q3"].append(relu(quartiles[3]))
        IQR["CI95HI"].append(relu(quartiles[4]))

        # Drop the per-slice objects before the next slice is built.
        del simulator
        del batchGenerator

    with open(out,"wb") as fOUT:
        pickle.dump(IQR,fOUT)

    return rho,IQR


def main():
    '''
    Command-line entry point: run the parametric bootstrap and write
    bias-corrected predictions with 95% confidence intervals.

    Steps, all relative to the project directory (``-d``; the current
    working directory when omitted):
      1. load the pickled simulation and batch parameters from
         ``networks/simPars.p`` and ``networks/batchPars.p``;
      2. require exactly one ``*.PREDICT.txt`` file in the project
         directory, otherwise print an error and exit with status 1;
      3. run ``ParametricBootStrap`` with the network stored in
         ``networks/`` and plot its results;
      4. pass the fifth column of every prediction line through
         ``get_corrected`` and write the result, plus two appended
         confidence-interval columns, to ``*.PREDICT.BSCORRECTED.txt``;
      5. remove the ``PBS`` directory holding the bootstrap tree files.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-d','--projectDir',dest='outDir',help='Directory for all project output. NOTE: the same projectDir must be used for all functions of ReLERNN',default=None)
    parser.add_argument('-t','--nCPU',dest='nCPU',help='Number of CPUs to use',type=int,default=1)
    parser.add_argument('--gpuID',dest='gpuID',help='Identifier specifying which GPU to use', type=int, default=0)
    parser.add_argument('--nSlice',dest='nSlice',help='Number of recombination rate bins to simulate over', type=int, default=100)
    parser.add_argument('--nReps',dest='nReps',help='Number of simulations per step', type=int, default=1000)
    args = parser.parse_args()


    ## Set up the directory structure and output files
    if not args.outDir:
        print("Warning: No project directory found, using current working directory.")
        projectDir = os.getcwd()
    else:
        projectDir = args.outDir
    trainDir = os.path.join(projectDir,"train")
    networkDir = os.path.join(projectDir,"networks")
    bs_resultFile = os.path.join(networkDir,"bootstrapResults.p")
    bs_plotFile = os.path.join(networkDir,"bootstrapPlot.pdf")
    modelWeights = [os.path.join(networkDir,"model.json"),os.path.join(networkDir,"weights.h5")]
    bsDir = os.path.join(projectDir,"PBS")


    ## Load simulation and batch pars
    # NOTE: pickle executes arbitrary code on load; these files are expected
    # to come from this project's own earlier steps, never from an
    # untrusted source.
    simParsFILE=os.path.join(networkDir,"simPars.p")
    batchParsFILE=os.path.join(networkDir,"batchPars.p")
    with open(simParsFILE, "rb") as fIN:
        simPars=pickle.load(fIN)
    with open(batchParsFILE, "rb") as fIN:
        batchPars=pickle.load(fIN)

    # Validate the prediction file before starting the expensive bootstrap.
    pred_resultFiles = glob.glob(os.path.join(projectDir,"*.PREDICT.txt"))
    if len(pred_resultFiles) < 1:
        print("Error: no .PREDICT.txt file found. You must run ReLERNN_PREDICT.py prior to running ReLERNN_BSCORRECT.py")
        sys.exit(1)
    elif len(pred_resultFiles) > 1:
        print("Error: multiple prediction files found.")
        sys.exit(1)
    pred_resultFile = pred_resultFiles[0]


    ## Run the parametric bootstrap
    ParametricBootStrap(
            simPars,
            batchPars,
            trainDir,
            network=modelWeights,
            slices=args.nSlice,
            repsPerSlice=args.nReps,
            gpuID=args.gpuID,
            out=bs_resultFile,
            tempDir=bsDir,
            nCPU=args.nCPU)


    ## Plot results from bootstrap
    plotParametricBootstrap(bs_resultFile,bs_plotFile)


    ## Load bootstrap values
    with open(bs_resultFile, "rb") as fIN:
        bs=pickle.load(fIN)


    ## Loop, correct, and write output
    # Swap only the trailing ".txt"; a plain str.replace would also rewrite
    # any ".txt" occurring earlier in the path (e.g. in a directory name).
    correctedfile=os.path.splitext(pred_resultFile)[0]+".BSCORRECTED.txt"
    header=["chrom","start","end","nSites","recombRate","CI95LO","CI95HI"]
    with open(correctedfile, "w") as fout, open(pred_resultFile, "r") as fin:
        for line in fin:
            if line.startswith("chrom"):
                # Header row: emit the extended header with the CI columns.
                fout.write("\t".join(header)+"\n")
                continue
            ar=line.split()
            rate=float(ar[4])
            # C[0] replaces the rate; C[1] and C[2] are written under the
            # CI95LO and CI95HI header columns.
            C=get_corrected(rate,bs)
            ar[4]=C[0]
            ar.extend([C[1],C[2]])
            fout.write("\t".join([str(x) for x in ar])+"\n")


    ## Remove the bootstrap tree files
    shutil.rmtree(bsDir)
    print("\n***ReLERNN_BSCORRECT.py FINISHED!***\n")


# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
