#!/home/jadrion/miniconda3/envs/deep/bin/python
"""
Reads a VCF file, estimates some simulation parameters, and simulates via msprime.
NOTE: This assumes that the user has previously QC'd and filtered the VCF.
"""

from ReLERNN.imports import *
from ReLERNN.helpers import *
from ReLERNN.manager import *
from ReLERNN.simulator import *


def _build_parser():
    """Build the command-line parser for the hotspot simulation script.

    Returns:
        argparse.ArgumentParser: parser exposing every option read by main().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v','--vcf',dest='vcf',help='Filtered and QC-checked VCF file. Important: Every row must correspond to a biallelic SNP with no missing data!')
    parser.add_argument('-g','--genome',dest='genome',help='BED-formatted (i.e. zero-based) file corresponding to chromosomes and positions to evaluate')
    parser.add_argument('-m','--mask',dest='mask',help='BED-formatted file corresponding to inaccessible bases', default=None)
    parser.add_argument('-d','--projectDir',dest='outDir',help='Directory for all project output. NOTE: the same projectDir must be used for all functions of ReLERNN',default=None)
    parser.add_argument('-n','--demographicHistory',dest='dem',help='Output file from either stairwayplot, SMC++, or MSMC',default=None)
    parser.add_argument('-u','--assumedMu',dest='mu',help='Assumed per-base mutation rate',type=float,default=1e-8)
    parser.add_argument('-l','--assumedGenTime',dest='genTime',help='Assumed generation time (in years)',type=float)
    parser.add_argument('-r','--upperRhoThetaRatio',dest='upRTR',help='Assumed upper bound for the ratio of rho to theta',type=float,default=1.0)
    parser.add_argument('-t','--nCPU',dest='nCPU',help='Number of CPUs to use',type=int,default=1)
    # Both flags write to 'phased'; with neither flag given the value is
    # False because the first registered default wins.
    parser.add_argument('--phased',help='Treat genotypes as phased',default=False, action='store_true')
    parser.add_argument('--unphased',dest='phased',help='Treat genotypes as unphased',action='store_false')
    parser.add_argument('--phaseError',dest='phaseError',help='Fraction of bases simulated with incorrect phasing',type=float,default=0.0)
    parser.add_argument('--maxSites',dest='winSizeMx',help='Max number of sites per window to train on. Important: too many sites causes problems in training (see README)!',type=int,default=1750)
    parser.add_argument('--forceWinSize',dest='forceWinSize',help='USED ONLY FOR TESTING, LEAVE AS DEFAULT',type=int,default=0)
    parser.add_argument('--maskThresh',dest='maskThresh',help='Discard windows where >= maskThresh percent of sites are inaccessible',type=float,default=1.0)
    parser.add_argument('--nTrain',dest='nTrain',help='Number of training examples to simulate',type=int,default=100000)
    parser.add_argument('--nVali',dest='nVali',help='Number of validation examples to simulate',type=int,default=1000)
    parser.add_argument('--nTest',dest='nTest',help='Number of test examples to simulate',type=int,default=1000)
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes the help text unrenderable.
    parser.add_argument('--nHotWins',dest='nHotWins',help='Number of total recombination windows per simulation (hotspot will only occur in one of windows, so nHotWins == 10 means 10%% of the chromosome is a hotspot)',type=int,default=10)
    return parser


def _read_genome(path):
    """Read a BED-formatted genome file into region strings.

    Args:
        path: path to a file whose non-blank lines each hold three
            whitespace-separated fields (chromosome, start, end).

    Returns:
        list: one "chrom:start-end" string per non-blank line, in file order.

    Exits with status 1 if a non-blank line does not have exactly three fields.
    """
    chromosomes = []
    with open(path, "r") as fIN:
        for line in fIN:
            ar = line.split()
            if not ar:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            if len(ar)!=3:
                print("Error: genome file must be formatted as a bed file (i.e.'chromosome     start     end')")
                sys.exit(1)
            chromosomes.append("{}:{}-{}".format(ar[0],ar[1],ar[2]))
    return chromosomes


def main():
    """Simulate hotspot training, validation and test sets matched to a VCF.

    Parses the command line, splits the VCF and counts sites per window via
    Manager, derives simulation parameters from those counts, runs three
    Simulator instances (train/vali/test), pickles the simulation parameters
    to <networkDir>/simPars.p, and prints a comparison of segregating-site
    counts between the simulations and the input VCF.

    Exits with status 1 on invalid or missing command-line input.
    """
    args = _build_parser().parse_args()

    # Ensure all required arguments are provided
    if not args.vcf:
        print("Error: a VCF file must be supplied via -v/--vcf")
        sys.exit(1)
    if not args.vcf.endswith(".vcf"):
        print('Error: VCF file must end in extension ".vcf"')
        sys.exit(1)
    if not args.genome:
        print("Error: a BED-formatted genome file must be supplied via -g/--genome")
        sys.exit(1)
    if not args.outDir:
        print("Warning: No project directory found, using current working directory.")
        projectDir = os.getcwd()
    else:
        projectDir = args.outDir
    if not args.mask:
        print("Warning: no accessibility mask found. All sites in the genome are assumed to be accessible.")
    if args.dem:
        demHist = check_demHist(args.dem)
        if demHist == -9:
            print("Error: demographicHistory file must be raw output from either stairwayplot, SMC++, or MSMC")
            sys.exit(1)
        if not args.genTime:
            print("Error: assumed generation time must be supplied when simulating under stairwayplot, SMC++, or MSMC")
            sys.exit(1)
    else:
        print("Warning: no demographic history file found. All training data will be simulated under demographic equilibrium.")
        demHist = 0
    if not args.phased and args.phaseError != 0.0:
        print("Error: non-zero 'phaseError' cannot be used in conjunction with '--unphased'")
        sys.exit(1)

    ## Set up the directory structure to store the simulations data.
    # NOTE(review): the pause presumably gives the user time to read the
    # warnings printed above -- confirm before removing.
    time.sleep(5)
    nProc = args.nCPU
    trainDir = os.path.join(projectDir,"train_hotspot")
    valiDir = os.path.join(projectDir,"vali_hotspot")
    testDir = os.path.join(projectDir,"test_hotspot")
    networkDir = os.path.join(projectDir,"networks_hotspot")
    vcfDir = os.path.join(projectDir,"splitVCFs_hotspot")

    ## Make directories if they do not exist
    for p in [projectDir,trainDir,valiDir,testDir,networkDir,vcfDir]:
        if not os.path.exists(p):
            os.makedirs(p)

    ## Read the genome file
    chromosomes = _read_genome(args.genome)

    ## Pass params to the vcf manager
    manager_params = {
        'vcf':args.vcf,
        'mask':args.mask,
        'winSizeMx':args.winSizeMx,
        'forceWinSize':args.forceWinSize,
        'chromosomes':chromosomes,
        'vcfDir':vcfDir,
        'projectDir':projectDir,
        'networkDir':networkDir,
    }
    vcf_manager = Manager(**manager_params)

    ## Split the VCF file
    vcf_manager.splitVCF(nProc=nProc)

    ## Calculate nSites per window
    wins, nSamps, maxS, maxLen = vcf_manager.countSites(nProc=nProc)

    ## Prepare the accessibility mask
    if args.mask:
        mask_fraction, win_masks = vcf_manager.maskWins(wins=wins, maxLen=maxLen, nProc=nProc)
    else:
        mask_fraction, win_masks = 0.0, None

    ## Define parameters for msprime simulation
    print("Simulating with window size = {} bp.".format(maxLen))
    if nSamps < 2:
        # The harmonic sum below would be zero and the division would fail.
        print("Error: nSamps must be at least 2 to estimate theta")
        sys.exit(1)
    # Harmonic number a = sum_{i=1}^{nSamps-1} 1/i, so thetaW = maxS / a.
    a = sum(1.0/(i+1) for i in range(nSamps-1))
    thetaW = maxS/a
    assumedMu = args.mu
    # Ne from theta = 4 * Ne * mu * L, with L reduced by the masked fraction.
    Ne = int(thetaW/(4.0 * assumedMu * ((1-mask_fraction) * maxLen)))
    rhoHi = assumedMu * args.upRTR

    # Settings shared by both simulation modes.
    dg_params = {
        'priorLowsRho':0.0,
        'priorHighsRho':rhoHi,
        'priorLowsMu':assumedMu * 0.66,
        'priorHighsMu':assumedMu * 1.33,
        'ChromosomeLength':maxLen,
        'winMasks':win_masks,
        'maskThresh':args.maskThresh,
        'phased':args.phased,
        'phaseError':args.phaseError,
        'hotspots':True,
        'nHotWins':args.nHotWins,
    }
    if demHist:
        # Simulate under the supplied demographic history.
        dg_params['MspDemographics'] = convert_demHist(args.dem, nSamps, args.genTime, demHist)
    else:
        # Simulate under equilibrium with the sample size and estimated Ne.
        dg_params['N'] = nSamps
        dg_params['Ne'] = Ne

    # Assign pars for each simulation
    dg_train = Simulator(**dg_params)
    dg_vali = Simulator(**dg_params)
    dg_test = Simulator(**dg_params)

    ## Dump simulation pars for use with parametric bootstrap
    # 'bn' is added only after the Simulators are built, so it is stored in
    # the pickle but never passed to Simulator.
    simParsFILE = os.path.join(networkDir,"simPars.p")
    with open(simParsFILE, "wb") as fOUT:
        dg_params["bn"] = os.path.basename(args.vcf).replace(".vcf","")
        pickle.dump(dg_params,fOUT)

    ## Simulate data
    print("\nTraining set:")
    dg_train.simulateAndProduceTrees(numReps=args.nTrain,direc=trainDir,simulator="msprime",nProc=nProc)
    print("Validation set:")
    dg_vali.simulateAndProduceTrees(numReps=args.nVali,direc=valiDir,simulator="msprime",nProc=nProc)
    print("Test set:")
    dg_test.simulateAndProduceTrees(numReps=args.nTest,direc=testDir,simulator="msprime",nProc=nProc)
    print("\nSIMULATIONS FINISHED!\n")

    ## Count number of segregating sites in simulation
    SS = []
    maxSegSites = 0
    minSegSites = float("inf")
    for ds in [trainDir,valiDir,testDir]:
        # info.p is read back from the directories this run just simulated
        # into; pickle must never be pointed at untrusted files.
        with open(os.path.join(ds,"info.p"),"rb") as fIN:
            DsInfoDir = pickle.load(fIN)
        segSites = DsInfoDir["segSites"]
        SS.extend(segSites)
        maxSegSites = max(maxSegSites,max(segSites))
        minSegSites = min(minSegSites,min(segSites))

    ## Compare counts of segregating sites between simulations and input VCF
    print("SANITY CHECK")
    print("====================")
    print("numSegSites\t\t\tMin\tMean\tMax")
    print("Simulated:\t\t\t%s\t%s\t%s" %(minSegSites, int(sum(SS)/float(len(SS))), maxSegSites))
    for win in wins:
        print("InputVCF %s:\t\t%s\t%s\t%s" %(win[0],win[3],win[4],win[5]))
    print("\n\n***ReLERNN_SIMULATE_HOTSPOT.py FINISHED!***\n")


# Run the simulation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
