#!/usr/bin/env python
# Create test set for use with pita
import os
import sys
import yaml
import argparse
import pybedtools
import subprocess as sp
from tempfile import NamedTemporaryFile

# Command-line interface: the three options below are all mandatory.
parser = argparse.ArgumentParser()
for flag, dest, description in (
    ("-c", "configfile", "Input configuration file"),
    ("-b", "bedfile", "BED file with regions"),
    ("-o", "output", "Output name"),
):
    parser.add_argument(flag, dest=dest, required=True, help=description)

args = parser.parse_args()
configfile, bedfile, outname = args.configfile, args.bedfile, args.output


# Parse YAML config file
# safe_load only constructs plain Python types (no arbitrary object tags) and,
# unlike a bare yaml.load(f), does not need an explicit Loader argument.
with open(configfile, "r") as config_handle:
    config = yaml.safe_load(config_handle)

# Everything below indexes the config by key, so an empty or scalar document
# cannot be processed.
if not isinstance(config, dict):
    print("Configuration file {0} does not contain a mapping!".format(configfile))
    sys.exit(1)

# Data directory: relative paths in the config are resolved against it
base = config.get("data_path", ".")

# Output data directory; refuse to overwrite an existing one
if os.path.exists(outname):
    print("Output path {0} already exists!".format(outname))
    sys.exit(1)
os.mkdir(outname)

# Get chromosomes from BED file
# Stored as a sorted plain list so that the result is deterministic and can be
# serialised by yaml.safe_dump (a dict keys view cannot).
chroms = sorted(set(feature[0] for feature in pybedtools.BedTool(bedfile)))
config["chromosomes"] = chroms


# Do all intersections with the BED file.
# The "repeats" section is optional in the YAML file; only "annotation" is required.
def intersect(f):
    """Write the features of one config entry that overlap the BED regions.

    The entry's 'path' is tried as given first and, if that does not exist,
    relative to the data directory ``base``. The overlapping features are
    saved under the same file name inside the output directory, and
    ``f['path']`` is rewritten in place to that bare file name.

    Exits the script with status 1 when the file cannot be found.
    """
    source = f['path']
    sys.stderr.write("Filtering {0}\n".format(source))

    if not os.path.exists(source):
        # Fall back to a location relative to the data directory.
        source = os.path.join(base, source)
        if not os.path.exists(source):
            print("File not found: {0}".format(source))
            sys.exit(1)

    filename = os.path.basename(source)
    features = pybedtools.BedTool(source)
    regions = pybedtools.BedTool(bedfile)
    # wa=True keeps the original feature lines rather than the overlaps.
    overlapping = features.intersect(regions, wa=True)
    overlapping.saveas(os.path.join(outname, filename))
    f['path'] = filename

# "repeats" may be absent (or empty) in the config; "annotation" is required.
# The optional section is handled explicitly instead of with a bare except,
# which would also swallow the SystemExit raised by intersect() for a missing
# file and then re-process annotation entries whose paths were already
# rewritten.
for f in config["annotation"] + (config.get("repeats") or []):
    intersect(f)
            


# Reduce every data track to the BED regions. BAM files are subset with
# samtools; every other file is intersected with pybedtools. Afterwards each
# entry's 'path' lists only the bare file names that were written to the
# output directory.
for f in config['data']:
    kept = []
    for rel_p in f['path']:
        path = rel_p
        if not os.path.exists(path):
            path = os.path.join(base, path)
        if not os.path.exists(path):
            # Same policy as intersect(): a missing input is an error, not an
            # "empty" track.
            print("File not found: {0}".format(path))
            sys.exit(1)
        filename = os.path.basename(path)
        out = os.path.join(outname, filename)
        sys.stderr.write("Filtering {0}\n".format(path))

        if not path.endswith("bam"):
            a = pybedtools.BedTool(path)
            b = pybedtools.BedTool(bedfile)
            a.intersect(b, wa=True).saveas(out)
            kept.append(filename)
            continue

        # NOTE: paths are interpolated into shell command lines unquoted, so
        # they must not contain whitespace or shell metacharacters.
        # The context manager removes the temporary file on every branch.
        with NamedTemporaryFile(prefix="pita.flatbread.") as tmp:
            # Build a SAM file: the original header first ...
            cmd = "samtools view -H {0} > {1}".format(path, tmp.name)
            sp.call(cmd, shell=True)
            # ... followed by the reads of every BED region.
            with open(bedfile) as regions:
                for line in regions:
                    vals = line.strip().split()
                    # Skip blank lines, comments and track/browser lines.
                    if (len(vals) < 3 or vals[0].startswith("#")
                            or vals[0] in ("track", "browser")):
                        continue
                    loc = "{0}:{1}-{2}".format(vals[0], vals[1], vals[2])
                    cmd = "samtools view {0} {1} >> {2}".format(path, loc, tmp.name)
                    sp.call(cmd, shell=True)

            # Without -h only alignment records are printed, so any output
            # means at least one read was collected.
            cmd = "samtools view -S {0} | head -n 1".format(tmp.name)
            pop = sp.Popen(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
            stdout, stderr = pop.communicate()

            if stdout.strip():
                # Convert to BAM via the output path, then move it back over
                # the temporary name so the sorted result can take the final
                # name.
                cmd = "samtools view -bS {0} > {1} && mv {1} {0}"
                sp.call(cmd.format(tmp.name, out), shell=True)
                # Strip only the trailing extension to get the sort prefix.
                prefix = out[:-len(".bam")] if out.endswith(".bam") else out
                # NOTE(review): "samtools sort <in> <prefix>" looks like the
                # legacy calling convention - confirm against the installed
                # samtools version.
                cmd = "samtools sort {0} {1} && samtools index {2}".format(tmp.name, prefix, out)
                sp.call(cmd, shell=True)
                kept.append(filename)
            else:
                sys.stderr.write("{} does not contain reads in these regions, leaving it out.\n".format(rel_p))

    # Only the files actually written remain; building the list directly
    # avoids removing original relative paths from a list of basenames.
    f['path'] = kept

# Record the absolute output directory as the data path of the new config.
config["data_path"] = os.path.abspath(outname)

# With an explicit encoding PyYAML emits encoded bytes, so the file has to be
# opened in binary mode (a text-mode handle rejects bytes on Python 3).
with open("{0}.yaml".format(outname), "wb") as f:
    yaml.safe_dump(config, f, encoding="utf-8", allow_unicode=True)
