#!/usr/bin/env python
# Divide data/annotation files into smaller subsets for use with pita
import os
import sys
import yaml
import subprocess as sp
from tempfile import NamedTemporaryFile as ntp
import argparse
import random
import pybedtools

# Command-line interface. Exactly one splitting mode is needed: -n (split by
# number of chromosomes) or -b (split every chromosome into regions, which
# additionally needs the chromosome size file given with -s).
p = argparse.ArgumentParser()
p.add_argument("-c",
               dest="configfile",
               help="Input configuration file",
               required=True
              )
p.add_argument("-n",
               dest="numchroms",
               type=int,
               help="Number of chromosomes / regions per file",
              )
p.add_argument("-b",
               dest="baseRegion",
               type=int,
               help="Number of basepairs / regions per file in megabases"
              )
p.add_argument("-s",
               dest="chrSizes",
               help="File with chromosome sizes (Ncbi format), only required in combination with -b flag"
              )
p.add_argument("-o",
               dest="output",
               help="Output base name",
               required=True
              )

args = p.parse_args()
configfile = args.configfile
numchroms = args.numchroms
sizes = args.chrSizes
outname = args.output
# -b is optional: only convert megabases to basepairs when it was supplied,
# otherwise None * 1000000 raises a TypeError for every -n only invocation.
# NOTE(review): further down this value is only tested for truthiness; the
# actual region size is derived from splitNum.
baseRegion = args.baseRegion * 1000000 if args.baseRegion else None
# Number of regions every chromosome is divided into in -b mode.
splitNum = 10

# Validate the mode selection before anything is written to disk.
if not numchroms and not baseRegion:
    p.error("one of -n or -b is required")
if not numchroms and not sizes:
    p.error("-s is required in combination with -b")

# Parse YAML config file
# safe_load replaces yaml.load without a Loader: the latter can construct
# arbitrary Python objects and is rejected by recent PyYAML releases. A config
# relying on Python-specific YAML tags will no longer load.
with open(configfile, "r") as f:
    config = yaml.safe_load(f)

# Output data directory; refuse to write into an existing path
if os.path.exists(outname):
    print("Output path {0} already exists!".format(outname))
    sys.exit(1)
os.mkdir(outname)

# Data directory that relative annotation paths are resolved against
base = config.get("data_path", ".")

def parseSizeFile(sizes):
    """Read a tab-separated chromosome size file into a dictionary.

    sizes -- path of a text file whose lines hold a chromosome name in the
             first column and its length in the second column.

    Returns a dict mapping chromosome name (str) to length (int). Columns
    after the second are ignored and blank lines are skipped.

    Raises IndexError for a non-blank line with fewer than two columns and
    ValueError when the second column is not an integer.
    """
    dic = {}
    with open(sizes) as sizeFile:
        for raw in sizeFile:
            stripped = raw.strip()
            # A trailing or stray empty line would otherwise fail on fields[1].
            if not stripped:
                continue
            fields = stripped.split("\t")
            dic[fields[0]] = int(fields[1])
    return dic

# Merges all annotation files so the regions can be determined more accurately later
def readAnnotation(annotation):
    """Combine all BED annotation files into one sorted and merged BED file.

    annotation -- iterable of mappings with a "path" and a "type" entry.
                  A path that does not exist as given is retried relative to
                  the module-level data directory `base`.

    Returns the path of a temporary file (not deleted automatically) holding
    the output of `bedtools sort | bedtools merge` over every entry whose
    type contains "bed". The file is empty when there is no such entry.

    Exits with status 1 when an annotation file cannot be found or when
    bedtools cannot be run successfully.
    """
    bedFound = False
    with ntp(delete=True) as combinedFile:
        for entry in annotation:
            path = entry["path"]
            kind = entry["type"]
            if not os.path.exists(path):
                path = os.path.join(base, path)
            if not os.path.exists(path):
                print("File not found: {}".format(path))
                sys.exit(1)
            if "bed" not in kind:
                continue
            bedFound = True
            # Copy in Python instead of `cat ... >>` through a shell, so paths
            # with spaces or shell metacharacters are handled correctly.
            lastChunk = b""
            with open(path, "rb") as src:
                for chunk in iter(lambda: src.read(1 << 20), b""):
                    combinedFile.write(chunk)
                    lastChunk = chunk
            # Keep the last record of this file from being glued to the first
            # record of the next file when the trailing newline is missing.
            if lastChunk and not lastChunk.endswith(b"\n"):
                combinedFile.write(b"\n")
        # bedtools reads the file by name, so buffered data must be on disk.
        combinedFile.flush()

        with ntp(delete=False) as mergedSorted:
            # Without any BED input there is nothing to sort or merge; the
            # empty output file is returned as is.
            if bedFound:
                try:
                    sorter = sp.Popen(["bedtools", "sort", "-i", combinedFile.name],
                                      stdout=sp.PIPE)
                    merger = sp.Popen(["bedtools", "merge", "-i", "-"],
                                      stdin=sorter.stdout, stdout=mergedSorted)
                except OSError as err:
                    print("Could not run bedtools: {}".format(err))
                    sys.exit(1)
                # Drop our copy of the pipe so only the merge process holds
                # the read end.
                sorter.stdout.close()
                mergeStatus = merger.wait()
                sortStatus = sorter.wait()
                if sortStatus != 0 or mergeStatus != 0:
                    print("bedtools sort/merge failed (exit codes {} / {})".format(
                        sortStatus, mergeStatus))
                    sys.exit(1)
    return mergedSorted.name
                                                                                            
# Compare the base regions with the actual annotation regions to make sure no interesting regions are cut off
def intersect(msorted, region):
    """Write the annotation intervals that overlap the given regions to a file.

    msorted -- path of the merged annotation BED file.
    region  -- path of a BED file with the raw regions to compare against.

    Returns the path of a new temporary BED file (not deleted automatically)
    holding the result of intersecting `msorted` with `region` using
    wa=True, i.e. the entries are reported as they appear in `msorted`
    rather than clipped to the region boundaries.
    """
    out = ntp(delete=False)
    # Only the name is needed; close the handle right away so it is not
    # leaked on every call. pybedtools writes to the path itself.
    out.close()
    a = pybedtools.BedTool(msorted)
    b = pybedtools.BedTool(region)
    a.intersect(b, wa=True).saveas(out.name)
    return out.name


# Chromosome names listed in the config; these are distributed over the parts.
chroms = config["chromosomes"]
# Annotation entries from the config; readAnnotation reads each entry's
# "path" and "type".
anno = config["annotation"]
# Path of the temporary file with the sorted and merged BED annotation.
# Note that this runs for every invocation, although only the -b branch
# below uses the result.
msorted = readAnnotation(anno)


def ranges(dic, chromNames=None, parts=None):
    """Split every chromosome into equally sized, contiguous regions.

    dic        -- dict mapping chromosome name to its length.
    chromNames -- chromosomes to split; defaults to the module-level `chroms`.
    parts      -- number of regions per chromosome; defaults to the
                  module-level `splitNum`.

    Returns a dict mapping each chromosome name to a list of exactly `parts`
    [name, start, stop] lists. The regions are contiguous, start at 0 and the
    last one always ends at the chromosome length, so no tail is lost when
    the length is not divisible by `parts`.

    Raises KeyError when a chromosome is missing from `dic`.
    """
    if chromNames is None:
        chromNames = chroms
    if parts is None:
        parts = splitNum
    exactRanges = {}
    for name in chromNames:
        size = dic[name]
        # Integer boundaries k * size / parts. Stepping by size / parts
        # instead yields an extra remainder region that callers indexing
        # 0..parts-1 never see, and fails for size < parts (step 0).
        bounds = [(k * size) // parts for k in range(parts + 1)]
        exactRanges[name] = [[name, bounds[k], bounds[k + 1]]
                             for k in range(parts)]
    return exactRanges

# Common prefix of every output part: <outname>/<last component of outname>.
# Using the last path component keeps nested output paths and trailing
# slashes working; for a plain name this equals "<outname>/<outname>".
partPrefix = os.path.join(outname, os.path.basename(os.path.normpath(outname)))

if numchroms:
    # Chromosome mode: shuffle the chromosomes and write one copy of the
    # config per group of `numchroms` chromosomes.
    random.shuffle(chroms)
    # Part numbers come from enumerate so they stay integers on Python 3.
    for part, i in enumerate(range(0, len(chroms), numchroms), 1):
        config['chromosomes'] = chroms[i: i + numchroms]
        with open("{0}_part_{1}.yaml".format(partPrefix, part), "w") as f:
            f.write(yaml.dump(config))

elif baseRegion:
    # Region mode: part N consists of the N-th region of every chromosome.
    dic = parseSizeFile(sizes)
    exactRanges = ranges(dic)
    print(exactRanges)

    failedParts = []
    for count in range(1, splitNum + 1):
        # File with the raw regions for the intersection; opened in text mode
        # because str is written to it.
        tmpFile = ntp(mode="w", delete=False)
        try:
            with tmpFile:
                for chrom in chroms:
                    r = exactRanges[chrom][count - 1]
                    tmpFile.write("{}\t{}\t{}\n".format(r[0], r[1], r[2]))
            # Getting the more sophisticated regions, based on the merged annotation
            correctRegions = intersect(msorted, tmpFile.name)
        finally:
            os.remove(tmpFile.name)

        # Argument list instead of a shell string, so paths need no quoting.
        cmd = ["flatbread", "-c", configfile, "-b", correctRegions,
               "-o", "{0}_part_{1}".format(partPrefix, count)]
        try:
            status = sp.call(cmd)
        finally:
            os.remove(correctRegions)
        if status != 0:
            print("flatbread failed for part {0} (exit code {1})".format(count, status))
            failedParts.append(count)

    # Keep processing the remaining parts on failure, but do not report success.
    if failedParts:
        sys.exit(1)
