#!/home/taewing/miniconda3/bin/python

import os
import sys
import argparse
import gzip
import csv
import logging
import multiprocessing as mp
import matplotlib
import random
import sqlite3

# Select the non-interactive 'Agg' backend so no X display is needed.
# This is done before matplotlib.pyplot is imported further down.
matplotlib.use('Agg')

# Illustrator compatibility: disable TeX text rendering and set the SVG
# font type to 'none' before any figure is created.
new_rc_params = {'text.usetex': False, "svg.fonttype": 'none'}
matplotlib.rcParams.update(new_rc_params)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch
import seaborn as sns

import numpy as np
import pandas as pd
import pysam
import scipy.stats as ss

import skbio.alignment as skalign
import skbio.sequence as skseq

from uuid import uuid4
from collections import defaultdict as dd
from collections import Counter
from itertools import product
from operator import itemgetter
from copy import deepcopy

from ont_fast5_api.fast5_interface import get_fast5_file

# Module-wide logging: timestamp followed by the message, INFO level and above.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Read:
    ''' Modification calls for one read, stored per position.

    Three parallel dicts share the same position keys:
      llrs       -- position -> statistic passed in as `stat`
      meth_calls -- position -> call, one of -1, 0, 1
      mod_names  -- position -> modification name
    '''

    def __init__(self, read_name, cpg_loc, stat, methcall, modname, phase=None):
        self.read_name = read_name
        self.phase = phase

        # per-position lookups, filled through add_mod()
        self.llrs = {}
        self.meth_calls = {}
        self.mod_names = {}

        # used for locus per-read plot
        self.ypos = None
        self.starts = []
        self.ends = []

        # a Read always starts out with one call
        self.add_mod(cpg_loc, stat, methcall, modname)

    def add_mod(self, cpg_loc, stat, methcall, modname):
        ''' record (or overwrite) the call at position cpg_loc '''
        assert methcall in (-1, 0, 1)

        for table, value in ((self.llrs, stat),
                             (self.meth_calls, methcall),
                             (self.mod_names, modname)):
            table[cpg_loc] = value

    def overlap(self, other):
        ''' True if the overall extents (min start .. max end) of the two reads share more than zero bases '''
        rightmost_shared = min(max(self.ends), max(other.ends))
        leftmost_shared = max(min(self.starts), min(other.starts))
        return rightmost_shared - leftmost_shared > 0


class Gene:
    ''' Gene model accumulated from annotation records.

    Tracks the outermost transcript and CDS extents seen so far and a
    start-sorted list of exon blocks. Blocks are two-element [start, end]
    sequences.
    '''

    def __init__(self, ensg, name):
        self.ensg = ensg
        self.name = name
        self.tx_start = None
        self.tx_end = None
        self.cds_start = None
        self.cds_end = None
        self.exons = []

    def add_exon(self, block):
        ''' add an exon block; self.exons is kept sorted by start '''
        assert len(block) == 2
        assert block[0] < block[1]
        self.exons.append(block)
        self.exons = sorted(self.exons, key=itemgetter(0))

    def add_tx(self, block):
        ''' widen the transcript extent so that it covers block '''
        assert len(block) == 2
        assert block[0] < block[1]
        if self.tx_start is None or self.tx_start > block[0]:
            self.tx_start = block[0]

        if self.tx_end is None or self.tx_end < block[1]:
            self.tx_end = block[1]

    def add_cds(self, block):
        ''' widen the CDS extent so that it covers block; inverted blocks are logged and ignored '''
        assert len(block) == 2
        if block[0] > block[1]:
            logger.warning('CDS block start > end in gene %s' % self.ensg)
            return None

        if self.cds_start is None or self.cds_start > block[0]:
            self.cds_start = block[0]

        if self.cds_end is None or self.cds_end < block[1]:
            self.cds_end = block[1]

    def has_tx(self):
        ''' True once both transcript bounds are set '''
        return None not in (self.tx_start, self.tx_end)

    def has_cds(self):
        ''' True once both CDS bounds are set '''
        return None not in (self.cds_start, self.cds_end)

    def merge_exons(self):
        ''' collapse overlapping exon blocks in place (blocks that only abut are kept separate) '''
        if len(self.exons) == 0:
            return

        new_exons = []
        last_block = self.exons[0]

        for block in self.exons[1:]:
            if min(block[1], last_block[1]) - max(block[0], last_block[0]) > 0: # overlap
                # keep growing the running block; it must NOT be replaced by
                # `block` afterwards or the merged extent is lost
                last_block = [min(block[0], last_block[0]), max(block[1], last_block[1])]

            else:
                new_exons.append(last_block)
                last_block = block

        new_exons.append(last_block)

        self.exons = new_exons


# Complement lookup for single upper-case bases (N maps to itself).
# Other symbols are not present and raise KeyError on direct indexing.
RC = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}

class Motif:
    ''' Sequence motif with one bracketed key base, e.g. "[C]G" or "G[A]TC".

    Attributes:
      motif       -- the motif string as given
      left_bases  -- bases before the bracket
      key_base    -- the single base inside the brackets (one of A, T, C, G)
      right_bases -- bases after the bracket
      full_seq    -- left_bases + key_base + right_bases
    '''

    def __init__(self, motif):

        self.motif = motif

        assert len(motif.split('[')) == 2, 'bad motif syntax: %s' % motif
        assert len(motif.split(']')) == 2, 'bad motif syntax: %s' % motif
        assert motif.index('[') < motif.index(']'), 'bad motif syntax: %s' % motif

        self.left_bases  = motif.split('[')[0]
        self.right_bases = motif.split(']')[1]
        self.key_base    = motif.split('[')[1].split(']')[0]
        self.full_seq    = self.left_bases + self.key_base + self.right_bases

        # earlier versions stored the left flank under this (misspelled) name;
        # kept as an alias in case anything reads it
        self.left_base = self.left_bases

        # a bare substring test would also accept '' or multi-base keys
        assert len(self.key_base) == 1 and self.key_base in 'ATCG', 'bad motif syntax: %s' % motif


    def match_seq(self, seq):
        ''' returns 0-based coords of key base for matches in seq

        The result is a defaultdict(list) mapping each key-base coordinate to
        a list holding the motif string. seq is upper-cased before matching.
        '''

        sites = dd(list)

        seq = seq.upper()

        width  = len(self.full_seq)
        offset = len(self.left_bases)

        # +1 so that a motif ending exactly at the end of seq is found
        for i in range(len(seq) - width + 1):
            if seq[i:i+width] == self.full_seq:
                sites[i + offset].append(self.motif)

        return sites
        

class Mod:
    ''' One motif site on a read plus its per-base modification values.

    The genome-side attributes (chrom, genome_pos, genome_base, aln_strand)
    start as None and are expected to be filled in before str() is used.
    '''

    def __init__(self, read_id, read_pos, read_base, sites, mods):
        # read-side information, known at construction
        self.read_id = read_id
        self.read_pos = read_pos
        self.read_base = read_base
        self.sites = sites
        self.mods = mods

        # genome-side information, assigned later by the caller
        self.chrom = None
        self.genome_pos = None
        self.genome_base = None
        self.aln_strand = None

    def __str__(self):
        ''' one tab-separated output row; sites are comma-joined, mods become one column each '''
        columns = [
            self.read_id,
            self.chrom,
            str(self.genome_pos),
            self.genome_base,
            self.aln_strand,
            str(self.read_pos),
            self.read_base,
            ','.join(self.sites),
            '\t'.join(str(value) for value in self.mods),
        ]
        return '\t'.join(columns)


def base_field(mod_metadata):
    ''' Build the list of column labels for a modified-base table.

    Walks the alphabet found under 'output_alphabet' (preferred) or
    'modified_base_alphabet'. Symbols A/C/T/G are kept as they are; every
    other symbol is replaced, in order, by the next whitespace-separated
    entry of 'modified_base_long_names'.
    '''
    long_names = mod_metadata['modified_base_long_names'].split()

    if 'output_alphabet' in mod_metadata:
        alphabet_key = 'output_alphabet'
    elif 'modified_base_alphabet' in mod_metadata:
        alphabet_key = 'modified_base_alphabet'
    else:
        alphabet_key = None

    labels = []
    next_long = 0

    for symbol in list(mod_metadata[alphabet_key]):
        if symbol in 'ACTG':
            labels.append(symbol)
            continue

        # non-canonical symbol: consume the next long name
        labels.append(long_names[next_long])
        next_long += 1

    return labels


def pandas_df(out, mod_base, can_base):
    ''' Turn tab-separated text lines into a DataFrame with log-probability columns.

    out      -- iterable of lines; the first line is the header (split on
                whitespace), the remaining lines are tab-separated rows
    mod_base -- header name of the column holding the modified-base value
    can_base -- header name of the column holding the canonical-base value

    Returns a DataFrame with one row per input row (string-valued columns named
    after the header, index starting at 1) plus three numeric columns:
    can_log_prob and mod_log_prob = log((value + 1) / 256), and
    llr = mod_log_prob - can_log_prob.

    A header with no data rows yields an empty frame that still carries every
    column, instead of failing on the column lookup.
    '''
    lines = iter(out)
    header = next(lines, '').strip().split()

    records = []
    for line in lines:
        # indexing header (rather than zip) keeps the IndexError for rows
        # that have more fields than the header
        records.append({header[j]: val for j, val in enumerate(line.strip().split('\t'))})

    data = pd.DataFrame(records, columns=header, index=range(1, len(records) + 1))

    data['can_log_prob'] = np.log((pd.to_numeric(data[can_base])+1)/256)
    data['mod_log_prob'] = np.log((pd.to_numeric(data[mod_base])+1)/256)
    data['llr'] = data['mod_log_prob'] - data['can_log_prob']

    return data


def guppy_f5_fetch(fast5, bam_db, args):
    ''' Extract modified-base values at motif sites from one fast5 file.

    fast5  -- file name inside the directory args.fast5
    bam_db -- sqlite file with a table bam(readname, sam) holding SAM text
    args   -- uses .fast5, .bam, .ref, .motif, .modname, .include_unmatched

    For each read, motif sites are located in the basecalled sequence, mapped to
    reference coordinates through the alignment stored in bam_db, and written
    (via pandas_df) to "<fast5 without extension>.gupmod.tsv" in args.fast5.

    Returns the output file name, or None when nothing could be written
    (missing basecall metadata, unknown modname, no reads, no placed sites).
    '''
    canonical = ('A', 'T', 'C', 'G')

    m = Motif(args.motif)

    bases = []
    header = None
    modbases = dd(list)

    with get_fast5_file(args.fast5 + '/' + fast5, mode="r") as f5:
        for read_id in f5.get_read_ids():
            read = f5.get_read(read_id)
            latest_basecall = read.get_latest_analysis('Basecall_1D')

            mod_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/ModBaseProbs')
            called_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/Fastq')

            mod_table_path = '{}/BaseCalled_template/ModBaseProbs'.format(latest_basecall)
            called_table_path = '{}/BaseCalled_template/Fastq'.format(latest_basecall)

            mod_metadata = read.get_analysis_attributes(mod_table_path)
            called_metadata = read.get_analysis_attributes(called_table_path)

            if None in (mod_metadata, called_metadata):
                logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
                return None

            # column layout is taken from the first read only
            if not bases:
                bases = base_field(mod_metadata)

                mod_names = [b for b in bases if b not in canonical]

                if args.modname not in mod_names:
                    logger.error('mod "%s" not found in metadata. Available mod names: %s' % (args.modname, ','.join(mod_names)))
                    return None

                header = "readname\tchrom\tpos\tgenome_base\tstrand\tread_pos\tread_base\tmotif\t" + '\t'.join(bases) + "\n"

            # second line of the FASTQ record is the sequence
            seq = called_base_table.split('\n')[1]

            assert len(seq) == len(mod_base_table)

            sites = m.match_seq(seq)

            for i, mods in enumerate(mod_base_table):
                if i in sites:
                    modbases[read_id].append(Mod(read_id, i, seq[i], sites[i], mods))

    logger.info('processed %d reads from %s' % (len(modbases), fast5))

    if header is None:
        # the file held no reads, so no column layout is known
        logger.warning('no reads found in %s' % fast5)
        return None

    logger.info('parsing alignments from %s' % bam_db)

    out = [header]

    bam = None
    ref = None
    conn = None

    try:
        # only the header is needed here. db_guppy treats args.bam as a
        # comma-separated list, so take the first entry.
        # NOTE(review): assumes all listed BAMs share the same references -- confirm
        bam = pysam.AlignmentFile(args.bam.split(',')[0])
        ref = pysam.Fastafile(args.ref)
        conn = sqlite3.connect(bam_db)
        c = conn.cursor()

        for readname in modbases:
            for row in c.execute("SELECT sam FROM bam WHERE readname=?", (readname,)):
                read = pysam.AlignedSegment.fromstring(row[0], bam.header)

                # query position -> reference position for aligned (non-gap) pairs
                pos_lookup = {}
                for ap in read.get_aligned_pairs():
                    if None not in ap:
                        pos_lookup[ap[0]] = ap[1]

                for modbase in modbases[read.query_name]:
                    modbase.aln_strand = '+'
                    aligned_read_pos = modbase.read_pos

                    if read.is_reverse:
                        modbase.aln_strand = '-'
                        aligned_read_pos = (len(read.seq) - modbase.read_pos)-1 # SAM format stores read on + strand

                    if aligned_read_pos in pos_lookup:
                        modbase.chrom = read.reference_name
                        modbase.genome_pos = pos_lookup[aligned_read_pos]
                        modbase.genome_base = ref.fetch(modbase.chrom, modbase.genome_pos, modbase.genome_pos+1)

                        modbase.genome_base = modbase.genome_base.upper()

                        if args.include_unmatched:
                            out.append(str(modbase) + '\n')

                        else:
                            if modbase.aln_strand == '+' and modbase.genome_base == modbase.read_base:
                                out.append(str(modbase) + '\n')

                            # .get: reference symbols missing from RC simply never match
                            if modbase.aln_strand == '-' and RC.get(modbase.genome_base) == modbase.read_base:
                                out.append(str(modbase) + '\n')

    finally:
        for handle in (conn, ref, bam):
            if handle is not None:
                handle.close()

    if len(out) == 1:
        # header only: nothing could be placed on the reference
        logger.warning('no motif sites from %s could be placed on the reference' % fast5)
        return None

    data = pandas_df(out, args.modname, m.key_base)
    out_fn = args.fast5 + '/' + '.'.join(fast5.split('.')[:-1]) + '.gupmod.tsv'
    logger.info('writing output from %s to %s' % (fast5, out_fn))
    data.to_csv(out_fn, sep='\t', index=False)

    return out_fn


def exclude_ambiguous_reads(fn, chrom, start, end, min_mapq=10):
    ''' Names of reads over chrom:start-end that extend beyond the interval.

    A read is kept when its mapping quality is at least min_mapq and its first
    aligned reference position is before start or its last one is after end.
    Reads without any aligned reference position are skipped.
    Names may repeat if a read has several alignments in the region.
    '''
    reads = []

    # context manager so the BAM handle is released when done
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            p = read.get_reference_positions()

            if not p:
                continue

            if p[0] < start or p[-1] > end:
                reads.append(read.query_name)

    return reads


def exclude_ambiguous_phased_reads(fn, chrom, start, end, min_mapq=10, tag_untagged=False, ignore_tags=False):
    ''' Map read name -> phase for reads that extend beyond chrom:start-end.

    Read selection: mapping quality at least min_mapq and first aligned
    reference position before start or last one after end. Reads without any
    aligned reference position are skipped.

    Phase value:
      "<PS>:<HP>"  when both tags are present and ignore_tags is False
      'unphased'   otherwise, if tag_untagged or ignore_tags is set
      None         otherwise
    '''
    default_phase = 'unphased' if (tag_untagged or ignore_tags) else None

    reads = {}

    # context manager so the BAM handle is released when done
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            p = read.get_reference_positions()

            if not p:
                continue

            if not (p[0] < start or p[-1] > end):
                continue

            phase = default_phase

            if not ignore_tags:
                tags = dict(read.get_tags())
                HP = tags.get('HP')
                PS = tags.get('PS')

                if None not in (HP, PS):
                    phase = str(PS) + ':' + str(HP)

            reads[read.query_name] = phase

    return reads


def get_ambiguous_reads(fn, chrom, start, end, min_mapq=10, w=50):
    ''' Names of reads over chrom:start-end regarded as ambiguous.

    A read is reported when its mapping quality is below min_mapq, or when its
    aligned reference positions lie entirely within (start-w, end+w), i.e. it
    does not reach at least w bases past either edge. Reads with no aligned
    reference positions are reported as well.
    '''
    reads = []

    # context manager so the BAM handle is released when done
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            # `not p` guards the p[0] / p[-1] lookups below
            if read.mapq < min_mapq or not p or (p[0] > start-w and p[-1] < end+w):
                reads.append(read.query_name)

    return reads


def get_reads(fn, chrom, start, end, min_mapq=10):
    ''' Names of reads fetched over chrom:start-end with mapping quality >= min_mapq.

    Names may repeat if a read has several alignments in the region.
    '''
    # context manager so the BAM handle is released when done
    with pysam.AlignmentFile(fn) as bam:
        return [read.query_name for read in bam.fetch(chrom, start, end) if read.mapq >= min_mapq]


def get_phased_reads(fn, chrom, start, end, min_mapq=10, tag_untagged=False, ignore_tags=False):
    ''' Map read name -> phase for reads over chrom:start-end with mapping quality >= min_mapq.

    Phase value:
      "<PS>:<HP>"  when both tags are present and ignore_tags is False
      'unphased'   otherwise, if tag_untagged or ignore_tags is set
      None         otherwise
    '''
    default_phase = 'unphased' if (tag_untagged or ignore_tags) else None

    reads = {}

    # context manager so the BAM handle is released when done
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            phase = default_phase

            if not ignore_tags:
                tags = dict(read.get_tags())
                HP = tags.get('HP')
                PS = tags.get('PS')

                if None not in (HP, PS):
                    phase = str(PS) + ':' + str(HP)

            reads[read.query_name] = phase

    return reads


def single_seq_fa(fn):
    ''' Read a FASTA file that holds exactly one entry and return its sequence.

    Sequence lines are stripped and concatenated. A header line encountered
    after sequence has already been collected trips the assertion.
    '''
    collected = ''

    with open(fn, 'r') as fa:
        for line in fa:
            if not line.startswith('>'):
                collected = collected + line.strip()
                continue

            # header line: nothing may have been read before it
            assert collected == '', 'input fa must have only one entry'

    return collected


def rc(dna):
    ''' reverse complement (IUPAC codes R/Y/M/K/B/D/H/V are handled, case is kept, other characters pass through) '''
    table = str.maketrans(
        'acgtrymkbdhvACGTRYMKBDHV',
        'tgcayrkmvhdbTGCAYRKMVHDB',
    )
    complemented = dna.translate(table)
    return complemented[::-1]


def get_modnames(meth_db):
    ''' Return the distinct values of column `mod` in table `modnames` of the sqlite file meth_db. '''
    conn = sqlite3.connect(meth_db)

    try:
        c = conn.cursor()
        return [row[0] for row in c.execute("SELECT DISTINCT mod FROM modnames")]

    finally:
        # the connection was previously left open
        conn.close()


def get_segmeth_calls(bam_fn, mod_names, meth_db, chrom, seg_start, seg_end, seg_name, seg_strand):
    ''' Count modification calls per mod name within one segment.

    Reads overlapping chrom:seg_start-seg_end are taken from bam_fn, their
    calls are looked up in the methdata table of meth_db, and calls whose
    position falls inside the segment (bounds inclusive) are tallied.

    Returns (seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand))
    where seg_result maps each name in mod_names to a defaultdict(int) of
    call value (-1, 0, 1) -> count.
    '''
    # NOTE(review): `args` is not a parameter of this function; this line relies
    # on a module-level `args` existing when the function runs -- confirm
    if args.excl_ambig:
        reads = exclude_ambiguous_reads(bam_fn, chrom, seg_start, seg_end)
    else:
        reads = get_reads(bam_fn, chrom, seg_start, seg_end)

    reads = list(set(reads))

    seg_reads = {}

    conn = sqlite3.connect(meth_db)

    try:
        c = conn.cursor()

        for index in reads:
            # bound parameter: read names with quotes no longer break the statement
            for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                cg_chrom, cg_start, stat, methstate, modname = row

                if chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                # positions are stored relative to the segment start
                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    finally:
        conn.close()

    seg_result = {}

    for modname in mod_names:
        seg_meth_calls = dd(int)

        for read in seg_reads.values():
            for loc, call in read.meth_calls.items():
                if read.mod_names[loc] == modname:
                    seg_meth_calls[call] += 1

        seg_result[modname] = seg_meth_calls

    return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand)



def get_meth_locus(args, bam, meth_db, mod):
    ''' Collect per-read calls of modification `mod` inside args.interval (used for locus plots).

    args    -- uses .interval ("chrom:start-end"), .excl_ambig,
               .unambig_highlight and .highlight (comma-separated "start-end"
               entries, optionally prefixed with "chrom:")
    bam     -- BAM file the reads are taken from
    meth_db -- sqlite file with the methdata table

    Returns a dict read name -> Read, with positions relative to the interval start.
    '''
    # used for locus plots
    # set up
    logger.info('fetching reads from %s with for mod %s' % (bam, mod))

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    elt_start, elt_end = map(int, pos.split('-'))

    # get list of relevant reads (excludes reads not anchored outside interval)
    if args.excl_ambig:
        reads = exclude_ambiguous_reads(bam, chrom, elt_start, elt_end)
    else:
        reads = get_reads(bam, chrom, elt_start, elt_end)

    reads = list(set(reads))

    if args.unambig_highlight and args.highlight:
        h_coords = []
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]

            h_coords += map(int, h.split('-'))

        h_coords.sort()

        # one span covering every highlight
        h_start, h_end = h_coords[0], h_coords[-1]

        # set: constant-time membership instead of scanning a list per read
        excl_reads = set(get_ambiguous_reads(bam, chrom, h_start, h_end))

        reads = [read for read in reads if read not in excl_reads]

    seg_reads = {}

    conn = sqlite3.connect(meth_db)

    try:
        c = conn.cursor()

        for index in reads:
            # bound parameter: read names with quotes no longer break the statement
            for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                cg_chrom, cg_start, stat, methstate, modname = row

                if modname != mod:
                    continue

                if chrom != cg_chrom:
                    continue

                if cg_start < elt_start or cg_start > elt_end:
                    continue

                cg_seg_start = cg_start - elt_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    finally:
        conn.close()

    return seg_reads


def get_meth_profile_composite(args, seg_chrom, seg_start, seg_end, seg_name, seg_strand):
    ''' Smoothed methylation profile of one segment, projected onto a reference element.

    args -- uses .teref (single-entry FASTA), .ref, .meth (sqlite methdata),
            .bam, .excl_ambig, .slidingwindowsize, .slidingwindowstep,
            .smoothwindowsize, .globalign and .lenfrac

    Steps: gather calls for reads over the segment, compute a sliding-window
    methylated fraction over dense-ranked site positions, smooth it, align the
    segment sequence to the reference element and translate site coordinates
    into reference-element coordinates.

    Returns (positions, profile) as two parallel lists. Both are empty when no
    calls fall in the segment, too few windows remain, nothing aligns, or the
    aligned fraction of either sequence is below args.lenfrac.
    '''
    logger.info('profiling %s %s:%d-%d:%s' % (seg_name, seg_chrom, seg_start, seg_end, seg_strand))

    te_ref_seq = single_seq_fa(args.teref)

    if args.excl_ambig:
        reads = exclude_ambiguous_reads(args.bam, seg_chrom, seg_start, seg_end)
    else:
        reads = get_reads(args.bam, seg_chrom, seg_start, seg_end)

    reads = list(set(reads))

    seg_reads = {}

    conn = sqlite3.connect(args.meth)

    try:
        c = conn.cursor()

        for index in reads:
            # bound parameter: read names with quotes no longer break the statement
            for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                cg_chrom, cg_start, stat, methstate, modname = row

                if seg_chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    finally:
        conn.close()

    if not seg_reads:
        # without this the table below is empty and the 'loc' lookup raises KeyError
        logger.warning('no calls found on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
        return [], []

    meth_table = dd(dict)
    sample = '.'.join(args.bam.split('.')[:-1])

    # one table row per (read, site)
    for name, read in seg_reads.items():
        for loc in read.llrs.keys():
            uuid = str(uuid4())
            meth_table[uuid]['loc'] = loc
            meth_table[uuid]['llr'] = read.llrs[loc]
            meth_table[uuid]['read'] = name
            meth_table[uuid]['sample'] = sample
            meth_table[uuid]['call'] = read.meth_calls[loc]

    meth_table = pd.DataFrame.from_dict(meth_table).T
    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['llr'] = pd.to_numeric(meth_table['llr'])

    # replace coordinates by their dense rank (1, 2, 3, ...) so that windows
    # are measured in sites; keep the original coordinate alongside
    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    coord_to_cpg = {}
    cpg_to_coord = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc
        cpg_to_coord[new_loc]  = orig_loc

    windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

    if len(windowed_methfrac) <= int(args.smoothwindowsize):
        logger.warning('too few sites after windowing: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
        return [], []

    smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize))

    coord_meth_pos = []

    cpg_meth_pos = list(windowed_methfrac.keys())

    # NOTE(review): window midpoints are used as keys into cpg_to_coord; this
    # presumes every midpoint is an existing site rank -- confirm
    for cpg in cpg_meth_pos:
        if seg_strand == '+':
            coord_meth_pos.append(cpg_to_coord[cpg])
        if seg_strand == '-':
            # flip so that coordinates run along the element's own strand
            coord_meth_pos.append((seg_end-seg_start)-cpg_to_coord[cpg])

    # alignment to ref elt

    ref = pysam.Fastafile(args.ref)

    try:
        elt_seq = ref.fetch(seg_chrom, seg_start, seg_end)
    finally:
        ref.close()

    if seg_strand == '-':
        elt_seq = rc(elt_seq)

    te_ref_seq = te_ref_seq.upper()
    elt_seq = elt_seq.upper()

    s_ref = skseq.DNA(te_ref_seq)
    s_elt = skseq.DNA(elt_seq)

    aln_res = []

    try:
        if args.globalign:
            aln_res = skalign.global_pairwise_align_nucleotide(s_ref, s_elt)
        else:
            aln_res = skalign.local_pairwise_align_ssw(s_ref, s_elt)
    except IndexError: # scikit-bio throws this if no bases align  >:|
        logger.warning('no align on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
        return [], []

    # aln_res[2] holds the (start, end) of the aligned span for each sequence
    coord_ref, coord_elt = aln_res[2]

    len_ref = coord_ref[1] - coord_ref[0]
    len_elt = coord_elt[1] - coord_elt[0]

    if len_ref / len(te_ref_seq) < float(args.lenfrac):
        logger.warning('ref align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_ref / len(te_ref_seq)))
        return [], []

    if len_elt / len(elt_seq) < float(args.lenfrac):
        logger.warning('elt align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_elt / len(elt_seq)))
        return [], []

    tab_msa = aln_res[0]

    elt_to_ref_coords = {}

    pos_ref = coord_ref[0]
    pos_elt = coord_elt[0]

    # walk the alignment columns, advancing whichever sequence has a base
    for pos in tab_msa.iter_positions():
        pos = list(pos)
        b_ref = pos[0]
        b_elt = pos[1]

        if '-' not in pos:
            elt_to_ref_coords[pos_elt] = pos_ref
            pos_ref += 1
            pos_elt += 1

        if b_elt == '-':
            pos_ref += 1

        if b_ref == '-':
            # element base with no counterpart in the reference element
            elt_to_ref_coords[pos_elt] = 'na'
            pos_elt += 1

    revised_coord_meth_pos = []
    meth_profile = []

    for pos, meth in zip(coord_meth_pos, smoothed_methfrac):
        if pos not in elt_to_ref_coords:
            continue

        revised_pos = elt_to_ref_coords[pos]

        if revised_pos != 'na':
            revised_coord_meth_pos.append(revised_pos)
            meth_profile.append(meth)


    return revised_coord_meth_pos, meth_profile


def _tabulate_site_calls(calls_by_loc, chrom, seg_start):
    ''' per-site counts: N = non-ambiguous calls, X = methylated calls; sites with N == 0 are left out '''
    table = dd(dict)

    for loc, calls in calls_by_loc.items():
        N = len([call for call in calls if call != 0]) # call 0 == ambiguous
        X = len([call for call in calls if call == 1]) # call 1 == methylated

        if N > 0:
            pos = loc+seg_start
            table[pos]['chr'] = chrom
            table[pos]['N'] = N
            table[pos]['X'] = X

    return table


def get_meth_calls_wg(bam_fn, meth_fn, chrom, seg_start, seg_end, phased, mod):
    ''' Per-site call counts for chrom:seg_start-seg_end, optionally split by haplotype.

    bam_fn  -- BAM file the reads are taken from
    meth_fn -- sqlite file with the methdata table
    phased  -- when true, only reads whose phase ends in ":1" or ":2" are
               counted, each towards its own haplotype; otherwise every read
               is counted as haplotype 1
    mod     -- modification name to keep, or None to keep all

    Returns [table_hap1, table_hap2]; each maps a genomic position to
    {'chr': chrom, 'N': non-ambiguous calls, 'X': methylated calls}.
    '''
    reads = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=phased)

    seg_methreads = {}

    conn = sqlite3.connect(meth_fn)

    try:
        c = conn.cursor()

        for index in reads:
            # bound parameter: read names with quotes no longer break the statement
            for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                cg_chrom, cg_start, stat, methstate, modname = row

                if mod is not None and mod != modname:
                    continue

                if chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_methreads:
                    seg_methreads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
                else:
                    seg_methreads[index].add_mod(cg_seg_start, stat, methstate, modname)

    finally:
        conn.close()

    # haplotype -> segment-relative position -> list of calls
    meth_data = {1: dd(list), 2: dd(list)}

    for read in seg_methreads.values():
        if phased:
            if read.phase is None or read.phase == 'unphased':
                continue

            # phase strings look like "<PS>:<HP>"
            hap = read.phase.split(':')[1]

            if hap not in ('1', '2'):
                continue

            phase = int(hap)

        else:
            # pass through as phase 1 if not analysing phases
            phase = 1

        for loc in read.llrs.keys():
            meth_data[phase][loc].append(read.meth_calls[loc])

    return [_tabulate_site_calls(meth_data[phase], chrom, seg_start) for phase in (1, 2)]


def slide_window(meth_table, sample, width=20, slide=2):
    ''' Sliding-window methylated fraction for one sample (used for locus plots, composite plots).

    meth_table needs the columns 'loc', 'sample' and 'call'. The window is
    advanced by `slide` before each evaluation and covers the open interval
    (win_start, win_end). Calls of 1 count as methylated, -1 as unmethylated.

    Returns (meth_frac, meth_n): dicts keyed by window midpoint holding the
    methylated fraction and the number of counted calls. Windows without any
    counted call are omitted.
    '''
    first_loc = min(meth_table['loc'])
    last_loc = max(meth_table['loc'])

    lo = int(first_loc - width/2)
    hi = lo + width

    def count_calls(state):
        # rows of this sample strictly inside the current window with the given call
        hits = (meth_table['sample'] == sample) & (meth_table['loc'] > lo) & (meth_table['loc'] < hi) & (meth_table['call'] == state)
        return len(meth_table.loc[hits])

    fractions = {}
    depths = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        n_meth = count_calls(1)
        n_unmeth = count_calls(-1)

        centre = int((lo+hi)/2)

        if n_meth + n_unmeth > 0:
            fractions[centre] = n_meth/(n_meth+n_unmeth)
            depths[centre] = n_meth+n_unmeth

    return fractions, depths


def slide_window_phased(meth_table, phase, width=20, slide=2):
    ''' Sliding-window methylated fraction for one phase.

    meth_table needs the columns 'loc', 'phase' and 'call'. The window is
    advanced by `slide` before each evaluation and covers the open interval
    (win_start, win_end). Calls of 1 count as methylated, -1 as unmethylated.

    Returns (meth_frac, meth_n): dicts keyed by window midpoint holding the
    methylated fraction and the number of counted calls. Windows without any
    counted call are omitted.
    '''
    first_loc = min(meth_table['loc'])
    last_loc = max(meth_table['loc'])

    lo = int(first_loc - width/2)
    hi = lo + width

    def count_calls(state):
        # rows of this phase strictly inside the current window with the given call
        hits = (meth_table['phase'] == phase) & (meth_table['loc'] > lo) & (meth_table['loc'] < hi) & (meth_table['call'] == state)
        return len(meth_table.loc[hits])

    fractions = {}
    depths = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        n_meth = count_calls(1)
        n_unmeth = count_calls(-1)

        centre = int((lo+hi)/2)

        if n_meth + n_unmeth > 0:
            fractions[centre] = n_meth/(n_meth+n_unmeth)
            depths[centre] = n_meth+n_unmeth

    return fractions, depths


def smooth(x, window_len=8, window='hanning'):
    ''' Smooth a 1-D array by convolving it with a normalised window (used for locus plots).

    modified from scipy cookbook: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    x          -- 1-D numpy array with more than window_len elements
    window_len -- even window size; below 3 the input is returned unchanged
    window     -- 'flat' (moving average), 'hanning', 'hamming', 'bartlett' or 'blackman'

    Returns an array of the same length as x. Violated preconditions raise
    AssertionError.
    '''

    assert window_len % 2 == 0, '--smoothwindowsize must be an even number'
    assert x.ndim == 1
    assert x.size > window_len, 'fewer data points than window_len'

    if window_len<3:
        return x

    assert window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']

    # pad both ends with mirrored copies of the signal so the ends are smoothed too
    s=np.r_[x[window_len-1:0:-1],x,x[-2:-window_len-1:-1]]

    if window == 'flat': #moving average
        w=np.ones(window_len,'d')
    else:
        # plain attribute lookup on numpy instead of eval() on a built string
        w=getattr(np, window)(window_len)

    y=np.convolve(w/w.sum(),s,mode='valid')

    # trim the padding so that the output lines up with x
    half = window_len // 2
    return y[(half-1):-half]


def mask_methfrac(data, cutoff=20):
    ''' Find runs of consecutive entries that do not exceed cutoff (used for locus plots).

    Returns a list of runs; each run is the list of consecutive indices i
    for which data[i] <= int(cutoff).
    '''
    passing = np.asarray(data) > int(cutoff)

    masked_runs = []
    run_start = None   # index where the current failing run began, if any

    for idx, ok in enumerate(passing):
        if not ok:
            if run_start is None:
                run_start = idx

        elif run_start is not None:
            # a passing value closes the open run
            masked_runs.append(list(range(run_start, idx)))
            run_start = None

    # a run still open at the end extends to the last index
    if run_start is not None:
        masked_runs.append(list(range(run_start, len(passing))))

    return masked_runs


def build_genes(gtf, chrom, start, end):
    ''' Build Gene objects from the records gtf.fetch(chrom, start, end) yields (used for locus plots).

    Every record is a tab-separated line with nine fields, the last being
    ';'-separated "key value" attributes. Records lacking gene_id or gene_name
    are skipped. exon / CDS / transcript records add their [start, end] block
    to the gene.

    Returns a dict gene_id -> Gene.
    '''
    # feature type -> name of the Gene method that takes its block
    block_adders = {'exon': 'add_exon', 'CDS': 'add_cds', 'transcript': 'add_tx'}

    genes = {}

    for record in gtf.fetch(chrom, start, end):
        (rec_chrom, source, feature, rec_start, rec_end,
         score, strand, frame, attribs) = record.split('\t')

        block = [int(rec_start), int(rec_end)]

        attr_dict = {}

        for attrib in attribs.strip().split(';'):
            if not attrib:
                continue

            # first two whitespace-separated tokens: key and (quoted) value
            key, val = attrib.strip().split()[:2]
            attr_dict[key.strip()] = val.strip().strip('"')

        if 'gene_id' not in attr_dict or 'gene_name' not in attr_dict:
            continue

        ensg = attr_dict['gene_id']

        if ensg not in genes:
            genes[ensg] = Gene(ensg, attr_dict['gene_name'])

        if feature in block_adders:
            getattr(genes[ensg], block_adders[feature])(block)

    return genes


## tools

def db_guppy(args):
    ''' Build (or extend) "<samplename>.guppy.db" from guppy fast5 files.

    args -- uses .fast5 (directory), .samplename, .bam (comma-separated list),
            .procs, .append, .minprob and .modname; guppy_f5_fetch reads
            further attributes from the same object

    Steps:
      1. cache primary alignments as SAM text in "<samplename>.bamcache.db"
         (an existing cache file is reused as it is)
      2. run guppy_f5_fetch on every *.fast5 in args.fast5 using a process pool
      3. load the resulting .gupmod.tsv files into the methdata table

    Exits when the output database exists without --append, or is missing with
    --append.
    '''
    assert os.path.exists(args.fast5), 'path not found: %s' % args.fast5

    db_fn = args.samplename + '.guppy.db'

    # checked up front so a naming clash is reported before any heavy work
    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    bam_db = args.samplename + '.bamcache.db'

    if os.path.exists(bam_db):
        logger.info('using existing bam cache db %s' % bam_db)

    else:
        conn = sqlite3.connect(bam_db)

        try:
            c = conn.cursor()

            logger.info('caching %s to %s' % (args.bam, bam_db))
            c.execute('''CREATE TABLE bam (readname text, sam text)''')
            c.execute('''CREATE INDEX read_index ON bam(readname)''')

            commit_interval = 100000

            for bam_fn in args.bam.split(','):
                with pysam.AlignmentFile(bam_fn) as bam:
                    for i, read in enumerate(bam.fetch(), 1):
                        if read.is_secondary or read.is_supplementary or read.is_duplicate:
                            continue

                        # qualities are not needed downstream; dropping them shrinks the cache
                        read.query_qualities = None

                        # bound parameters: SAM text may contain quote characters
                        c.execute("INSERT INTO bam VALUES (?, ?)", (read.query_name, read.to_string()))

                        if i % commit_interval == 0:
                            logger.info('commiting %d records to %s...' % (commit_interval, bam_db))
                            conn.commit()

                logger.info('commiting remaining records to %s...' % bam_db)

            conn.commit()

        finally:
            conn.close()

    fast5s = [fn for fn in os.listdir(args.fast5) if fn.endswith('.fast5')]

    logger.info('found %d fast5 files in %s' % (len(fast5s), args.fast5))

    logger.info('fetching base modifications...')

    # the with-block tears the pool down once every result has been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = [pool.apply_async(guppy_f5_fetch, [fast5, bam_db, args]) for fast5 in fast5s]
        outfiles = [res.get() for res in results]

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        minprob  = float(args.minprob)
        mod_base = args.modname

        for fn in outfiles:
            # guppy_f5_fetch returns None for files it could not process
            if fn is None:
                continue

            with open(fn) as tsv:
                logger.info('loading %s into %s' % (fn, db_fn))
                csv_reader = csv.DictReader(tsv, delimiter='\t')

                for row in csv_reader:
                    readname = row['readname']
                    chrom    = row['chrom']
                    pos      = int(row['pos'])
                    strand   = row['strand']

                    mod_log_prob = float(row['mod_log_prob'])
                    can_log_prob = float(row['can_log_prob'])

                    llr = mod_log_prob - can_log_prob

                    # The TSV columns are log((value+1)/256). The threshold and
                    # the sum check below are written for probabilities, so
                    # convert back before comparing; raw log values are never
                    # positive and would fail any positive minprob.
                    mod_prob = float(np.exp(mod_log_prob))
                    can_prob = float(np.exp(can_log_prob))

                    # adjust position of - strand calls to match nanopolish scheme
                    if strand == '-':
                        pos -= 1

                    methcall = 0

                    if mod_prob >= minprob:
                        methcall = 1

                    if can_prob >= minprob:
                        methcall = -1

                    assert (mod_prob + can_prob) < 1.01 # allow for some rounding error

                    # stat keeps two decimals, as the formatted statement did before
                    ins_data = (chrom, pos, strand, readname, float('%.2f' % llr), methcall, mod_base)
                    c.execute("INSERT INTO methdata VALUES (?, ?, ?, ?, ?, ?, ?)", ins_data)

                c.execute("INSERT INTO modnames VALUES (?)", (args.modname,))

                logger.info('commiting records from %s to %s' % (fn, db_fn))
                conn.commit()

    finally:
        conn.close()

    logger.info('finished.')


def db_megalodon(args):
    '''
    build (or append to) an sqlite database of per-read modified base calls from megalodon output

    args.methdata : comma-delimited list of megalodon per-read TSVs (optionally .gz) with columns
                    read_id, chrm, pos, strand, mod_log_prob, can_log_prob and mod_base
    args.db       : optional database name ('.db' is appended if it is missing)
    args.append   : add records to an existing database instead of creating a new one
    args.minprob  : probability cutoff used to call a site modified (1) or canonical (-1)

    calls sys.exit if the database exists and --append was not given, if it does not exist
    and --append was given, or if the default megalodon filename is used without --db
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.megalodon.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    # loop invariants
    minprob = float(args.minprob)
    progress_interval = 500000

    # placeholders instead of string formatting: values taken from the input file
    # (e.g. read names containing quotes) can no longer break or alter the statement
    insert_sql = 'INSERT INTO methdata VALUES (?, ?, ?, ?, ?, ?, ?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            opener = gzip.open if tsv.endswith('.gz') else open

            logger.info('parsing ' + tsv)

            # dict used as an insertion-ordered set of mod names seen in this file
            modnames = {}

            with opener(tsv, 'rt') as methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                for i, row in enumerate(csv_reader, 1):
                    readname = row['read_id']
                    chrom    = row['chrm']
                    pos      = int(row['pos'])
                    strand   = row['strand']

                    mod_log_prob = float(row['mod_log_prob'])
                    can_log_prob = float(row['can_log_prob'])
                    mod_base     = row['mod_base']

                    modnames[mod_base] = True

                    llr = mod_log_prob - can_log_prob

                    # adjust position of - strand calls to match nanopolish scheme
                    if strand == '-':
                        pos -= 1

                    # the input columns are log-probabilities: convert before thresholding
                    mod_prob = np.exp(mod_log_prob)
                    can_prob = np.exp(can_log_prob)

                    methcall = 0

                    if mod_prob >= minprob:
                        methcall = 1

                    if can_prob >= minprob:
                        methcall = -1

                    # the probabilities (not their logs) must sum to at most 1
                    assert (mod_prob + can_prob) < 1.01 # allow for some rounding or floating point error

                    # stat is stored rounded to two decimal places, as before
                    c.execute(insert_sql, (chrom, pos, strand, readname, round(llr, 2), methcall, mod_base))

                    if i % progress_interval == 0:
                        logger.info('processed %d records from %s' % (i, tsv))

            for mod in modnames:
                c.execute('INSERT INTO modnames VALUES (?)', (mod,))

            logger.info('commiting records from %s to %s' % (tsv, db_fn))
            conn.commit()

    finally:
        # always release the database handle, even if parsing fails part-way
        conn.close()


def db_nanopolish(args):
    '''
    build (or append to) an sqlite database of per-read methylation calls from nanopolish output

    args.methdata   : comma-delimited list of nanopolish call-methylation TSVs (optionally .gz)
                      with columns chromosome, strand, start, read_name, log_lik_ratio,
                      num_motifs and sequence
    args.db         : optional database name ('.db' is appended if it is missing)
    args.append     : add records to an existing database instead of creating a new one
    args.modname    : modification name recorded for every call
    args.scalegroup : divide the log-likelihood ratio by num_motifs before thresholding
    args.thresh     : absolute LLR cutoff; above it -> methylated (1), below -thresh -> unmethylated (-1)

    calls sys.exit if the database exists and --append was not given, or if it does not
    exist and --append was given
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    if basename.endswith('.tsv'):
        basename = '.'.join(basename.split('.')[:-1])

    db_fn = basename + '.nanopolish.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    # loop invariants
    thresh = float(args.thresh)
    mod_base = args.modname
    progress_interval = 500000

    # placeholders instead of string formatting: values taken from the input file
    # (e.g. read names containing quotes) can no longer break or alter the statement
    insert_sql = 'INSERT INTO methdata VALUES (?, ?, ?, ?, ?, ?, ?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            opener = gzip.open if tsv.endswith('.gz') else open

            logger.info('parsing ' + tsv)

            with opener(tsv, 'rt') as methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                for i, row in enumerate(csv_reader, 1):
                    r_start = int(row['start'])
                    llr     = float(row['log_lik_ratio'])
                    seq     = row['sequence']

                    if args.scalegroup:
                        llr = llr/float(row['num_motifs'])

                    methcall = 0

                    if llr > thresh:
                        methcall = 1

                    elif llr < -thresh:
                        methcall = -1

                    # stat is stored rounded to two decimal places, as before
                    stat = round(llr, 2)

                    # get per-CG position (nanopolish/calculate_methylation_frequency.py):
                    # every CG in the group sequence gets its own record, offset from the first CG
                    cg_pos = seq.find("CG")
                    first_cg_pos = cg_pos

                    while cg_pos != -1:
                        cg_start = r_start + cg_pos - first_cg_pos
                        cg_pos = seq.find("CG", cg_pos + 1)

                        c.execute(insert_sql, (row['chromosome'], cg_start, row['strand'], row['read_name'], stat, methcall, mod_base))

                    if i % progress_interval == 0:
                        logger.info('processed %d records from %s' % (i, tsv))

            c.execute('INSERT INTO modnames VALUES (?)', (args.modname,))

            logger.info('commiting records from %s to %s' % (tsv, db_fn))

            conn.commit()

    finally:
        # always release the database handle, even if parsing fails part-way
        conn.close()


def segmeth(args):
    '''
    segment methylation stats over genomic intervals

    args.data      : whitespace-delimited file with one "bam methylation_db" pair per line
    args.intervals : whitespace-delimited file with chrom, start, end and optionally
                     name and strand columns
    args.procs     : number of worker processes
    args.excl_ambig: only affects the output filename in this function

    writes <intervals>.<data>[.excl_ambig].segmeth.tsv with one row per segment and
    meth_calls / unmeth_calls / no_calls / methfrac columns per sample and mod name.
    Segments lacking results for any sample are reported and left out of the table.
    '''

    stats = [
    'meth_calls',
    'unmeth_calls',
    'no_calls',
    'methfrac'
    ]

    mod_names = []

    data = {}

    with open(args.data) as data_fh:
        for line in data_fh:
            if not line.strip():
                continue  # tolerate blank lines

            bam, meth = line.strip().split()
            data[bam] = meth

            mod_names.extend(get_modnames(meth))

    # sorted so that the output column order is reproducible between runs
    mod_names = sorted(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    base_names = ['.'.join(bam.split('.')[:-1]) for bam in data]

    data_basename = '.'.join(args.data.split('.')[:-1])
    ivl_basename = '.'.join(args.intervals.split('.')[:-1])
    outfn = '.'.join((ivl_basename, data_basename, 'segmeth.tsv'))

    if args.excl_ambig:
        outfn = '.'.join((ivl_basename, data_basename, 'excl_ambig', 'segmeth.tsv'))

    logger.info('segmeth output filename: %s' % outfn)

    # parse the interval file once rather than once per sample
    segments = []

    with open(args.intervals) as ivl_fh:
        for line in ivl_fh:
            c = line.strip().split()

            if not c:
                continue  # tolerate blank lines

            chrom, seg_start, seg_end = c[:3]

            seg_name = c[3] if len(c) > 3 else 'NA'
            seg_strand = c[4] if len(c) > 4 else 'NA'

            segments.append((chrom, int(seg_start), int(seg_end), seg_name, seg_strand))

    meth_segs = dd(dict)

    # the context manager shuts the worker pool down once every result has been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        for bam_fn, meth_fn in data.items():
            base_name = '.'.join(bam_fn.split('.')[:-1])

            for chrom, seg_start, seg_end, seg_name, seg_strand in segments:
                res = pool.apply_async(get_segmeth_calls, [bam_fn, mod_names, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand])
                results.append((res, base_name))

        for res, base_name in results:
            meth_result, seg = res.get()

            if meth_result is None:
                continue

            seg_id = '%s:%d-%d' % seg[:3]

            seg_chrom, seg_start, seg_end, seg_name, seg_strand = map(str, seg)

            seg_data = meth_segs[seg_id]

            seg_data['seg_id']     = seg_id
            seg_data['seg_chrom']  = seg_chrom
            seg_data['seg_start']  = seg_start
            seg_data['seg_end']    = seg_end
            seg_data['seg_name']   = seg_name
            seg_data['seg_strand'] = seg_strand

            if seg_name == 'NA':
                seg_data['seg_name'] = 'NoName'

            if seg_strand == 'NA':
                seg_data['seg_strand'] = '.'

            for modname, meth_data in meth_result.items():
                # meth_data is indexed by call state: -1 unmethylated, 0 no call, 1 methylated
                unmeth_calls = meth_data[-1] if -1 in meth_data else 0
                no_calls     = meth_data[0] if 0 in meth_data else 0
                meth_calls   = meth_data[1] if 1 in meth_data else 0

                col_prefix = os.path.basename(base_name) + '_' + modname

                seg_data[col_prefix + '_meth_calls'] = meth_calls
                seg_data[col_prefix + '_unmeth_calls'] = unmeth_calls
                seg_data[col_prefix + '_no_calls'] = no_calls

                if meth_calls+unmeth_calls == 0:
                    seg_data[col_prefix + '_methfrac'] = 'NaN'
                else:
                    seg_data[col_prefix + '_methfrac'] = meth_calls/float(meth_calls+unmeth_calls)

    header = ['seg_id', 'seg_chrom', 'seg_start', 'seg_end', 'seg_name', 'seg_strand']

    for bn in base_names:
        for m in mod_names:
            for s in stats:
                header.append('%s_%s_%s' % (os.path.basename(bn), m, s))

    with open(outfn, 'w') as out:
        out.write('\t'.join(header) + '\n')

        for mseg, seg_data in meth_segs.items():
            missing = [h for h in header if h not in seg_data]

            if missing:
                # actually skip the segment: a truncated row would leave the table ragged
                logger.warning('no calls for sample %s in segment %s, skipped.' % (missing[0], mseg))
                continue

            out.write('\t'.join(str(seg_data[h]) for h in header) + '\n')


def segplot(args):
    '''
    strip plots / violin plots of segmented methylation data over categories by sample

    args.segmeth    : table written by segmeth (first column is used as the index)
    args.samples    : optional comma-delimited subset of samples to plot
    args.mincalls   : minimum meth+unmeth calls required in every plotted sample
    args.categories : optional comma-delimited order of seg_name groups on the x axis
    args.violin     : draw violins instead of a strip plot
    args.pointsize, args.tiltlabel, args.ymin, args.ymax, args.width, args.height, args.svg
                    : plot appearance and output format

    writes <segmeth basename>.segplot_data.csv and <segmeth basename>.segplot.svg / .png
    '''

    data = pd.read_csv(args.segmeth, sep='\t', header=0, index_col=0)

    # sample names are the column prefixes in front of '_meth_calls'
    samples = [c.replace('_meth_calls', '') for c in data.columns if c.endswith('_meth_calls')]

    logger.info('available samples: %s' % ','.join(samples))

    if args.samples:
        user_samples = []
        for s in args.samples.split(','):
            assert s in samples, '%s not in sample list!' % s
            user_samples.append(s)

        samples = user_samples

    for s in samples:
        data[s + '_methfrac'] = data[s+'_meth_calls']/(data[s+'_meth_calls']+data[s+'_unmeth_calls'])

    # keep segments where no selected sample falls below the call cutoff
    mincalls = int(args.mincalls)
    keep = np.ones(len(data), dtype=bool)

    for s in samples:
        total_calls = (data[s+'_meth_calls'] + data[s+'_unmeth_calls']).values
        keep &= ~(total_calls < mincalls)

    data = data.loc[keep]

    logger.info('useable sites: %d' % len(data))

    plot_data = dd(dict)

    # groups in order of first appearance; used as the x-axis order unless overridden
    order = []

    for seg in data.index:
        for s in samples:
            uid = seg + ':' + s
            plot_data[uid]['sample'] = s
            plot_data[uid]['mCpG']   = data[s+'_methfrac'].loc[seg]
            plot_data[uid]['group']  = data['seg_name'].loc[seg]

            if plot_data[uid]['group'] not in order:
                order.append(plot_data[uid]['group'])

    plot_data = pd.DataFrame.from_dict(plot_data).T
    # round-trip through a dict so pandas re-infers column dtypes (the transpose leaves them as object)
    plot_data = pd.DataFrame(plot_data.to_dict())

    basename = '.'.join(args.segmeth.split('.')[:-1])

    plot_data.to_csv(basename+'.segplot_data.csv')
    logger.info('plot data written to %s.segplot_data.csv' % basename)

    pt_sz = int(args.pointsize)

    if args.categories is not None:
        order = args.categories.split(',')

    if args.violin:
        # 'jitter' is a stripplot option, not a violinplot one, so it is not passed here
        sns_plot = sns.violinplot(x='group', y='mCpG', data=plot_data, hue='sample', dodge=True, order=order, hue_order=samples)

    else:
        sns_plot = sns.stripplot(x='group', y='mCpG', data=plot_data, hue='sample', dodge=True, jitter=True, size=pt_sz, order=order, hue_order=samples)

    if args.tiltlabel:
        sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=45)

    sns_plot.set_ylim(float(args.ymin),float(args.ymax))

    fig = sns_plot.figure
    fig.set_size_inches(int(args.width), int(args.height))

    if args.svg:
        fig.savefig(basename+'.segplot.svg', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.svg' % basename)
    else:
        fig.savefig(basename+'.segplot.png', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.png' % basename)


def locus(args):
    '''
    plot methylation profile of a region in CpG space

    Draws a five-panel figure for args.interval (chrom:start-end): gene models (when
    args.gtf is given), per-read methylation calls, the correspondence between genomic
    coordinates and CpG (dense-rank) space, per-sample log-likelihood ratios, and a
    smoothed sliding-window methylated fraction.

    args.data is a whitespace-delimited file with one "bam methylation_db" pair per line.
    args.mods optionally restricts the plotted modifications (comma-delimited); when it
    is not given every modification found in the databases is plotted.

    The figure is saved as <prefix>.locus.meth.svg or <prefix>.locus.meth.png, where the
    prefix is built from the data filename, the interval, the mods and (optionally) args.genes.
    '''
    # set up

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos= args.interval.split(':')
    elt_start, elt_end = map(int, pos.split('-'))

    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    data = {}

    with open(args.data) as data_fh:
        for line in data_fh:
            if not line.strip():
                continue  # tolerate blank lines

            bam, meth_db = line.strip().split()
            data[bam] = meth_db

    # table for plotting

    meth_table = dd(dict)

    sample_order = []

    reads = {}
    orig_bam = {}

    use_mods = None
    if args.mods:
        use_mods = args.mods.split(',')

    # mods actually plotted, in order of first appearance (names the output when -m/--mods is absent)
    plotted_mods = []

    for bam, meth_db in data.items():
        mods = sorted(get_modnames(meth_db))
        logger.info('found mods: %s in db %s' % (','.join(mods), meth_db))

        for mod in mods:
            if use_mods:
                if mod not in use_mods:
                    logger.info('skipping %s, not specified in -m/--mods %s' % (mod, args.mods))
                    continue

            if mod not in plotted_mods:
                plotted_mods.append(mod)

            # one "sample" per bam + mod combination
            bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
            orig_bam[bamname] = bam
            reads[bamname] = get_meth_locus(args, bam, meth_db, mod)

            for name, read in reads[bamname].items():
                for loc in read.llrs.keys():
                    uuid = str(uuid4())
                    meth_table[uuid]['loc'] = loc
                    meth_table[uuid]['llr'] = read.llrs[loc]
                    meth_table[uuid]['read'] = name
                    meth_table[uuid]['sample'] = bamname
                    meth_table[uuid]['call'] = read.meth_calls[loc]

                    if bamname not in sample_order:
                        sample_order.append(bamname)

    if not meth_table:
        # nothing to plot: stop with a message instead of failing on a missing column below
        sys.exit('no methylation calls found in %s' % args.interval)

    meth_table = pd.DataFrame.from_dict(meth_table).T
    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['llr'] = pd.to_numeric(meth_table['llr'])

    # CpG space: replace each coordinate by its dense rank so sites are evenly spaced
    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]

            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            # highlight coordinates are made relative to the interval start
            h_start[-1] -= elt_start
            h_end[-1] -= elt_start

            # snap each highlight edge to the nearest site that has a call
            h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])


    fig = plt.figure()
    gs = gridspec.GridSpec(5,1,height_ratios=[1,5,1,3,3])

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(5,1,height_ratios=p_ratios)

    # build the palette once and assign colours in sample order
    sample_palette = sns.color_palette(args.samplepalette, n_colors=len(sample_order))
    sample_color = dict(zip(sample_order, sample_palette))

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, elt_start, elt_end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    # i counts the gene rows drawn so far
    i = 0
    for ensg in genes:
        if genes_of_interest:
            if genes[ensg].name not in genes_of_interest:
                continue

        if genes[ensg].has_tx():

            tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [0.4+i, 0.4+i], zorder=1))

            print('transcript: %d-%d %s' % (genes[ensg].tx_start, genes[ensg].tx_end, genes[ensg].name))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start

            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, i], exon_len, 0.8, edgecolor='#777777', facecolor='#ff4500', zorder=2))


        blocks_str = ','.join(['%d-%d' % (s,e) for s, e in genes[ensg].exons])
        print('%s exons: %s' % (genes[ensg].name, blocks_str))

        i += 1

    # keep the gene panel at least three rows tall
    if i < 3:
        i = 3

    ax0.set_ylim(0,i)
    ax0.set_yticks([])

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)


    # per-read plot

    logger.info('building read alignment plot...')
    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    readstack = dd(list)

    max_y  = 1
    pack_y = 1

    for bamname in reads:
        fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])

        # read name -> (alignment starts, alignment ends); a read may have several alignments
        pos_cache = {}

        for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
            if read.mapq < 10:
                continue

            if read.query_name not in pos_cache:
                pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
            else:
                pos_cache[read.query_name][0].append(read.reference_start)
                pos_cache[read.query_name][1].append(read.reference_end)

        for readname, read in reads[bamname].items():
            # every read starts on its own row; rows are merged in the packing step below
            read.ypos = max_y
            max_y += 1

            # NOTE(review): assumes every read returned by get_meth_locus passes the
            # mapq filter above, otherwise this lookup raises KeyError - confirm
            read.starts, read.ends = pos_cache[readname]
            readstack[bamname].append(read)

        fetch_reads_bam.close()

    # read packing

    for bamname in readstack:
        y = dd(list)

        for read in readstack[bamname]:
            y[read.ypos].append(read)

        # greedily move reads into row p when they overlap nothing already there
        # NOTE(review): y[q] is modified while it is being iterated; left as-is because
        # changing the traversal would change the resulting layout
        for p in y:
            for q in y:
                if p == q:
                    continue

                for read_q in y[q]:
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

        # renumber the non-empty rows consecutively across samples
        for p in y:
            if len(y[p]) > 0:
                for read in sorted(y[p], key=lambda r: min(r.starts)):
                    read.ypos = pack_y
                pack_y += 1

    ax1.set_ylim(0,pack_y+1)

    for bamname in readstack:
        for read in readstack[bamname]:
            for i in range(len(read.starts)):
                # clip each alignment to the plotted interval
                readline_start = max(read.starts[i], elt_start) - elt_start
                readline_end   = min(read.ends[i], elt_end) - elt_start

                for call_pos, call in read.meth_calls.items():
                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=2)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=2)


                ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos-0.1, read.ypos-0.1], zorder=1, color=sample_color[bamname], alpha=0.4))


    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    # x1: positions in CpG space (bottom axis); x2: matching interval-relative coordinates (top axis)
    x1 = []
    x2 = []

    step = int(args.cpgspace)

    # link the first and last table entries plus every step-th entry in between
    for i, x in enumerate(meth_table['orig_loc']):
        if i in (0, len(meth_table['orig_loc'])-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])


    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        # links that fall inside a highlight take that highlight's colour
        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax3.set_xticks([])

    n_ticks = 10
    tick_interval = (elt_end-elt_start)/n_ticks

    # intervals shorter than n_ticks bases would give a zero step, which range() rejects
    tick_step = max(1, int(tick_interval))
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), tick_step))

    # tick positions are interval-relative; label them with genomic coordinates
    xt_labels = [str(int(t+elt_start)) for t in tick_list]
    ax0.set_xticks(tick_list) 
    ax0.set_xticklabels(xt_labels)

    # llr plot

    logger.info('building llr plot...')
    ax4 = plt.subplot(gs[3])

    ax4.axhline(y=2.5, c='k', linestyle='--',lw=1)
    ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)
    ax4.axhline(y=-2.5, c='k', linestyle='--',lw=1)

    ax4 = sns.lineplot(x='loc', y='llr', hue='sample', data=meth_table, palette=args.samplepalette)
    ax4.set_xlim(ax2.get_xlim())

    # meth frac plot

    logger.info('building meth frac plot...')
    ax5 = plt.subplot(gs[4])

    # increasing zorder so each masking overlay is drawn above the lines plotted before it
    order_stack = 1

    for sample in sample_order:
        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize))

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=4, color=sample_color[sample])

        order_stack += 1

        for seg in masked_segs:
            if len(seg) > 2:
                mf_seg = np.asarray(smoothed_methfrac)[seg]
                pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]

                # fade masked stretches with a semi-transparent white overlay
                ax5.plot(pos_seg, mf_seg, marker='', lw=4, color='#ffffff', alpha=0.8, zorder=order_stack)

                order_stack += 1

    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((-0.05,1.05))

    # without -m/--mods, name the output after the mods that were actually plotted
    mods_label = ''.join(use_mods) if use_mods else ''.join(plotted_mods)

    fn_prefix = '.'.join(args.data.split('.')[:-1]) + '.' + '_'.join(args.interval.split(':')[:2]) + '.' + mods_label

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix


    fig.set_size_inches(16, 8)
    if args.svg:
        plt.savefig('%s.locus.meth.svg' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.svg' % fn_prefix)

    else:
        plt.savefig('%s.locus.meth.png' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.png' % fn_prefix)


def haplocus(args):
    '''
    methylation plots from haplotagged .bams
    '''

    # set up
    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    elt_start, elt_end = map(int, pos.split('-'))

    bamname = '.'.join(os.path.basename(args.bam).split('.')[:-1])
    fn_prefix = os.path.basename(bamname) + '.' + '_'.join(args.interval.split(':')[:2])
    fn_prefix += '.' + ''.join(args.mods.split(','))

    # highlight
    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    conn = sqlite3.connect(args.methdata)
    c = conn.cursor()

    # get list of relevant reads (exludes reads not anchored outside interval)
    reads = {}
    if args.excl_ambig:
        reads = exclude_ambiguous_phased_reads(args.bam, chrom, elt_start, elt_end, tag_untagged=args.tag_untagged, ignore_tags=args.ignore_tags)
    else:
        reads = get_phased_reads(args.bam, chrom, elt_start, elt_end, tag_untagged=args.tag_untagged, ignore_tags=args.ignore_tags)

    readnames = list(set(reads.keys()))

    if args.unambig_highlight and args.highlight:
        h_coords = []
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_coords += map(int, h.split('-'))

        h_coords.sort()

        h_min, h_max = h_coords[0], h_coords[-1]

        excl_reads = get_ambiguous_reads(args.bam, chrom, h_min, h_max)

        new_reads = []
        for read in readnames:
            if read not in excl_reads:
                new_reads.append(read)

        readnames = new_reads

    methreads = {}

    for index in readnames:
        for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname='%s' ORDER BY pos" % index):

            cg_chrom, cg_start, stat, methstate, modname = row

            if chrom != cg_chrom:
                continue

            if cg_start < elt_start or cg_start > elt_end:
                continue

            cg_seg_start = cg_start - elt_start

            if index not in methreads:
                methreads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
            else:
                methreads[index].add_mod(cg_seg_start, stat, methstate, modname)


    # table for plotting
    meth_table = dd(dict)

    phase_order = []
    orig_phase = {}

    for name, read in methreads.items():
        if read.phase is None:
            continue

        for loc in read.llrs.keys():
            if args.mods:
                if read.mod_names[loc] not in args.mods.split(','):
                    continue

            phase_mod = read.phase + '.' + read.mod_names[loc]
            orig_phase[phase_mod] = read.phase

            uuid = str(uuid4())
            meth_table[uuid]['loc'] = loc
            meth_table[uuid]['llr'] = read.llrs[loc]
            meth_table[uuid]['read'] = name
            meth_table[uuid]['phase'] = phase_mod
            meth_table[uuid]['call'] = read.meth_calls[loc]

            if phase_mod not in phase_order:
                phase_order.append(phase_mod)

    logger.info('phase + mod combinations: %s' % ','.join(phase_order))

    meth_table = pd.DataFrame.from_dict(meth_table).T
    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['llr'] = pd.to_numeric(meth_table['llr'])

    # cpg space
    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            h_start[-1] -= elt_start
            h_end[-1] -= elt_start

            h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])


    fig = plt.figure()
    gs = gridspec.GridSpec(5,1,height_ratios=[1,5,1,3,3])

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(5,1,height_ratios=p_ratios)

    phase_color = {}
    for i, phase in enumerate(phase_order):
        phase_color[phase] = sns.color_palette(args.phasepalette, n_colors=len(phase_order))[i]

    # plot genes
    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    genes = []

    if args.gtf is not None:
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, elt_start, elt_end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    i = 0
    for ensg in genes:
        if genes_of_interest:
            if genes[ensg].name not in genes_of_interest:
                continue

        if genes[ensg].has_tx():

            tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [0.4+i, 0.4+i], zorder=1))

            print('transcript: %d-%d %s' % (genes[ensg].tx_start, genes[ensg].tx_end, genes[ensg].name))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start

            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, i], exon_len, 0.8, edgecolor='#777777', facecolor='#ff4500', zorder=2))


        blocks_str = ','.join(['%d-%d' % (s,e) for s, e in genes[ensg].exons])
        print('%s exons: %s' % (genes[ensg].name, blocks_str))

        i += 1

    if i < 3:
        i = 3

    ax0.set_ylim(0,i)
    ax0.set_yticks([])

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    readstack = dd(list)

    max_y  = 1
    pack_y = 1

    for phase in phase_order:
        fetch_reads_bam = pysam.AlignmentFile(args.bam)
        pos_cache = {}

        for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
            if read.mapq < 10:
                continue

            if read.query_name not in pos_cache:
                pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
            else:
                pos_cache[read.query_name][0].append(read.reference_start)
                pos_cache[read.query_name][1].append(read.reference_end)

        for readname, read in methreads.items():
            if read.phase != orig_phase[phase]:
                continue

            read.ypos = max_y
            max_y += 1

            read.starts, read.ends = pos_cache[readname]
            readstack[phase].append(deepcopy(read)) # append deep copy to avoid reassigning ypos

        fetch_reads_bam.close()

    # read packing

    for phase in phase_order:
        reads = readstack[phase]

        y = dd(list)

        for read in readstack[phase]:
            y[read.ypos].append(read)

        for p in y:
            for q in y:
                if p == q:
                    continue

                for read_q in y[q]:
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

        for p in y:
            if len(y[p]) > 0:
                for read in sorted(y[p], key=lambda r: min(r.starts)):
                    read.ypos = pack_y
                pack_y += 1

    ax1.set_ylim(0,pack_y+1)

    for phase in phase_order:
        for read in readstack[phase]:
            for i in range(len(read.starts)):
                readline_start = max(read.starts[i], elt_start) - elt_start
                readline_end   = min(read.ends[i], elt_end) - elt_start

                for call_pos, call in read.meth_calls.items():
                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=phase_color[phase], mfc='white', markersize=2, zorder=2)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=2)


                ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos-0.1, read.ypos-0.1], zorder=1, color=phase_color[phase], alpha=0.4))


    # plot correspondence between genome space and cpg space
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.cpgspace)

    for i, x in enumerate(meth_table['orig_loc']):
        if i in (0, len(meth_table['orig_loc'])-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])

    
    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)


    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax3.set_xticks([])

    n_ticks = 10
    tick_interval = (elt_end-elt_start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    xt_labels = [str(int(t+elt_start)) for t in tick_list]
    ax0.set_xticks(tick_list) 
    ax0.set_xticklabels(xt_labels)

    # llr plot

    ax4 = plt.subplot(gs[3])

    ax4.axhline(y=2.5, c='k', linestyle='--',lw=1)
    ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)
    ax4.axhline(y=-2.5, c='k', linestyle='--',lw=1)

    ax4 = sns.lineplot(x='loc', y='llr', hue='phase', data=meth_table, palette=args.phasepalette)

    ax4.set_xlim(ax2.get_xlim())

    # sliding window plot

    ax5 = plt.subplot(gs[4])

    order_stack = 1

    for phase in phase_order:
        windowed_methfrac, meth_n = slide_window_phased(meth_table, phase, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize))

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=4, color=phase_color[phase])

        order_stack += 1

        for seg in masked_segs:
            if len(seg) > 2:
                mf_seg = np.asarray(smoothed_methfrac)[seg]
                pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]
            
                ax5.plot(pos_seg, mf_seg, marker='', lw=4, color='#ffffff', alpha=0.8, zorder=order_stack)

                order_stack += 1


    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((-0.05,1.05))

    fig.set_size_inches(16, 8)

    imgtype = 'png'

    if args.svg:
        imgtype = 'svg'

    if args.ignore_tags:
        plt.savefig('%s.unphased.meth.%s' % (fn_prefix, imgtype), bbox_inches='tight')
        logger.info('plot saved to %s.unphased.meth.%s' % (fn_prefix, imgtype))
    else:
        plt.savefig('%s.phased.meth.%s' % (fn_prefix, imgtype), bbox_inches='tight')
        logger.info('plot saved to %s.phased.meth.%s' % (fn_prefix, imgtype))


def region(args):
    '''
    plotting function for windowed view of larger regions

    Splits args.interval (chrom:start-end) into roughly args.windows
    equal-width windows and submits one get_segmeth_calls job per
    (window, bam) pair to a pool of args.procs workers.

    args.data is a text file with one ".bam methylation-db" pair per line;
    args.mods optionally restricts the mods (comma-delimited).
    '''

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    start, end = map(int, pos.split('-'))

    assert start < end

    if end-start < 100000:
        logger.warning('locus smaller than 100kbp. "tmnt region" might work, but also have a look at "tmnt locus" for more detailed output')

    # clamp to 1: an interval shorter than the window count would otherwise
    # give a zero step, which range() rejects
    w_size = max(1, (end-start) // int(args.windows))
    w_starts = list(range(start, end, w_size))

    data = {}
    mods = []

    with open(args.data) as data_fh:
        for line in data_fh:
            if not line.strip():
                continue # tolerate blank lines

            bam, meth_db = line.strip().split()
            data[bam] = meth_db

            mods += get_modnames(meth_db)

    # de-duplicate across databases; sorted so the log line is reproducible
    mods = sorted(set(mods))

    logger.info('found mods: %s' % ','.join(mods))

    if args.mods:
        requested = args.mods.split(',')

        for mod in requested:
            assert mod in mods, 'mod %s not found' % mod

        mods = requested

    # pending results grouped per input .bam (each job covers all mods)
    results = dd(list)

    with mp.Pool(processes=int(args.procs)) as pool:
        for seg_start in w_starts:
            # do not let the last window run past the requested interval
            seg_end = min(seg_start + w_size, end)

            for bam_fn, meth_fn in data.items():
                seg_strand = '.'
                seg_name = bam_fn
                res = pool.apply_async(get_segmeth_calls, [bam_fn, mods, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand])
                results[bam_fn].append(res)

        # wait for the submitted jobs so they are not killed when the pool
        # is torn down
        pool.close()
        pool.join()

    # NOTE(review): nothing visible here consumes `results` or draws a plot
    # yet -- the function looks unfinished; confirm intended output.


def composite(args):
    '''
    plot composite methylation profiles relative to a consensus element

    Reads the segmeth table (args.segdata), keeps segments whose seg_name is
    listed in args.fams and that have at least args.mincalls calls for the
    sample, computes one profile per segment via get_meth_profile_composite
    in a pool of args.procs workers, and overlays a random subset (at most
    args.outelts) of the profiles beneath a track marking the CG positions
    of the consensus sequence loaded from args.teref.

    Output: <sample>.<fams joined by "_">.composite.png (or .svg with --svg)
    '''
    te_ref_seq = single_seq_fa(args.teref)

    assert os.path.exists(args.ref + '.fai'), 'ref fasta must be indexed'

    fams = args.fams.split(',')

    logger.info('fams: %s' % ','.join(fams))

    segdata = pd.read_csv(args.segdata, sep='\t', header=0, index_col=0)

    # sample name defaults to the .bam basename without its last extension
    sample = '.'.join(os.path.basename(args.bam).split('.')[:-1])

    if args.sample is not None:
        sample = args.sample

    meth_col = sample+'_meth_calls'
    unmeth_col = sample+'_unmeth_calls'

    assert meth_col in segdata
    assert unmeth_col in segdata

    min_calls = int(args.mincalls)
    useable = []

    for seg in segdata.index:
        if segdata['seg_name'].loc[seg] not in fams:
            continue

        if (segdata[meth_col].loc[seg] + segdata[unmeth_col].loc[seg]) < min_calls:
            continue

        useable.append(seg)

    # cap the amount of work by subsampling elements
    if len(useable) > int(args.maxelts):
        useable = random.sample(useable, int(args.maxelts))

    segdata = segdata.loc[useable]

    logger.info('useable TEs: %d' % len(useable))

    out_res = [] # cache for --outelts

    # the context manager releases the worker processes once all profiles
    # have been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        for seg in segdata.index:
            seg_chrom  = str(segdata['seg_chrom'].loc[seg])
            seg_start  = int(segdata['seg_start'].loc[seg])
            seg_end    = int(segdata['seg_end'].loc[seg])
            seg_strand = str(segdata['seg_strand'].loc[seg])
            seg_name   = str(segdata['seg_name'].loc[seg])

            res = pool.apply_async(get_meth_profile_composite, [args, seg_chrom, seg_start, seg_end, seg_name, seg_strand])

            results.append(res)

        for res in results:
            coord_meth_pos, meth_profile = res.get()

            if len(coord_meth_pos) == 0:
                continue

            out_res.append((coord_meth_pos, meth_profile))

    fig = plt.figure()
    gs = gridspec.GridSpec(2,1,height_ratios=[1,8])

    # cpg
    ax0 = plt.subplot(gs[0])

    cpg_start = 0
    cpg_end = len(te_ref_seq)

    if args.start:
        cpg_start = int(args.start)

    if args.end:
        cpg_end = int(args.end)

    assert cpg_start < cpg_end

    ax0.set_xlim((cpg_start, cpg_end))

    # background bar: anchored at cpg_start so it covers the displayed range
    # even when --start is non-zero
    box = matplotlib.patches.Rectangle([cpg_start, 0], cpg_end-cpg_start, 1.0, edgecolor='#555555', facecolor='#cfcfcf', zorder=1)
    ax0.add_patch(box)

    if args.blocks:
        with open(args.blocks) as blocks:
            for line in blocks:
                if not line.strip():
                    continue # tolerate blank lines

                # b_name is parsed but not drawn
                b_start, b_end, b_name, b_col = line.strip().split()
                b_start = int(b_start)
                b_end = int(b_end)

                box = matplotlib.patches.Rectangle([b_start, 0], b_end-b_start, 1.0, edgecolor='#555555', facecolor=b_col, zorder=2)
                ax0.add_patch(box)

    # positions of the C in every CG dinucleotide within the plotted range
    cpg_locs = []

    for i in range(len(te_ref_seq)-1):
        if cpg_start <= i <= cpg_end and te_ref_seq[i] == 'C' and te_ref_seq[i+1] == 'G':
            cpg_locs.append(i)

    ax0.vlines(cpg_locs, 0, 1, lw=1, colors=('#FF4500'), zorder=3)

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    # wiggles
    ax1 = plt.subplot(gs[1])

    if len(out_res) < int(args.outelts):
        logger.info('available profiles (%d) is less than --outelts (%d)' % (len(out_res), int(args.outelts)))
        args.outelts = len(out_res)

    for coord_meth_pos, meth_profile in random.sample(out_res, int(args.outelts)):
        ax1.plot(coord_meth_pos, meth_profile, lw=float(args.linewidth), alpha=float(args.alpha), color=args.colour)

    ax1.set_xlim((cpg_start, cpg_end))
    ax1.set_ylim((-0.05,1.05))

    fig.set_size_inches(16, 6)

    fn_base = sample + '.' + '_'.join(fams)

    imgtype = 'png'

    if args.svg:
        imgtype = 'svg'

    out_fn = '%s.composite.%s' % (fn_base, imgtype)

    plt.savefig(out_fn, bbox_inches='tight')
    logger.info('plotted to %s' % out_fn)


def _write_wgmeth_table(table, outfn, methdata, dss):
    '''
    sort one per-site table by (chr, pos) and write it to outfn

    table:    DataFrame indexed by position with columns 'chr', 'N' and 'X'
    outfn:    output file name
    methdata: methylation database name; without its last extension it is
              used as the bedMethyl "name" column
    dss:      True writes chr/pos/N/X with a header (DSS input), False writes
              headerless bedMethyl-style rows
    '''
    table['pos'] = table.index
    table = table.sort_values(['chr', 'pos'])

    logger.info('writing %s' % outfn)

    if dss:
        table.to_csv(outfn, columns=['chr','pos','N','X'], index=False, sep='\t') # 1-based
        return

    table['score']  = table['N']
    table['score']  = table['score'].where(table['score'] <= 1000, 1000) # cap score at 1000
    table['start']  = table['pos']-1 # 0-based
    table['end']    = table['pos']
    table['pct']    = table['X']/table['N']*100 # percentage, not fraction
    table['name']   = '.'.join(methdata.split('.')[:-1])
    table['strand'] = '.'
    table['colour'] = '0,0,0'

    # start/end are repeated on purpose: they fill the thickStart/thickEnd fields
    table.to_csv(outfn, columns=['chr','start','end','name', 'score', 'strand', 'start', 'end', 'colour', 'N', 'pct'], header=False, index=False, sep='\t')


def wgmeth(args):
    '''
    generates whole genome output in DSS or bedmethyl format

    Every chromosome listed in args.fai is cut into bins of args.binsize
    bases; get_meth_calls_wg is run on each bin in a pool of args.procs
    workers and the per-bin tables are concatenated and written out.
    With args.phased one file per phase (0 and 1) is written, otherwise a
    single file. args.dss selects DSS output instead of bedMethyl.

    If -m/--mod is not given and the database holds exactly one mod, that
    mod is used; otherwise the program exits after logging a warning.
    '''

    bin_size = int(args.binsize)

    mods = sorted(get_modnames(args.methdata))

    mod = args.mod

    if mod is None:
        # the "missing -m" case has to be handled before the membership test,
        # otherwise None is always reported as an unknown mod
        if len(mods) > 1:
            logger.warning('more than one mod exists, need to pick one with -m/--mod: %s' % ','.join(mods))
            sys.exit()

        if len(mods) == 0:
            logger.warning('no mods found in db: %s' % args.methdata)
            sys.exit()

        mod = mods[0]
        logger.info('using mod: %s' % mod)

    elif mod not in mods:
        logger.warning('mod %s not in known mods for db: %s' % (mod, ','.join(mods)))
        sys.exit()

    phases = (0,1) if args.phased else (0,)

    seg_meth_table_store = dd(list)

    # the context manager releases the workers once every bin is collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.fai) as fai:
            for line in fai:
                if not line.strip():
                    continue # tolerate blank lines

                chrom, chrlen = line.strip().split()[:2]
                chrlen = int(chrlen)

                for seg_start in range(0, chrlen, bin_size):
                    # last bin of a chromosome is truncated at its end
                    seg_end = min(seg_start + bin_size, chrlen)

                    res = pool.apply_async(get_meth_calls_wg, [args.bam, args.methdata, chrom, seg_start, seg_end, args.phased, mod])

                    results.append(res)

        for res in results:
            seg_meth_table = res.get()

            if seg_meth_table is None:
                continue

            for phase in phases:
                seg_meth_table_store[phase].append(pd.DataFrame.from_dict(seg_meth_table[phase]).T)

    out_base = '.'.join(os.path.basename(args.bam).split('.')[:-1])
    suffix = 'DSS.txt' if args.dss else 'methyl.bed'

    for phase in phases:
        if not seg_meth_table_store[phase]:
            # pd.concat() raises on an empty list
            logger.warning('no methylation calls found (phase %d), nothing written' % phase)
            continue

        if args.phased:
            outfn = '%s.%s.phase_%d.%s' % (out_base, str(mod), phase, suffix)
        else:
            outfn = '%s.%s.%s' % (out_base, str(mod), suffix)

        _write_wgmeth_table(pd.concat(seg_meth_table_store[phase]), outfn, args.methdata, args.dss)

def main(args):
    '''
    log the invoking command line, then run the selected sub-command

    args: parsed argparse namespace; args.func is the handler that the
          chosen sub-parser registered via set_defaults(func=...)
    '''
    command_line = ' '.join(sys.argv)
    logger.info('starting tmnt with command: %s' % command_line)

    handler = args.func
    handler(args)


if __name__ == '__main__':
    # Command-line interface: one sub-parser per tool. Each sub-parser stores
    # its handler in args.func, which main() then calls.
    # Numeric options are mostly left as strings here and cast by the handlers.
    parser = argparse.ArgumentParser(description='transposon methylation nanopore tools')
    subparsers = parser.add_subparsers(title="tool", dest="tool")
    subparsers.required = True

    parser_nanopolish = subparsers.add_parser('db-nanopolish')
    parser_megalodon  = subparsers.add_parser('db-megalodon')
    parser_guppy      = subparsers.add_parser('db-guppy')
    parser_segmeth    = subparsers.add_parser('segmeth')
    parser_segplot    = subparsers.add_parser('segplot')
    parser_locus      = subparsers.add_parser('locus')
    parser_haplocus   = subparsers.add_parser('haplocus')
    parser_region     = subparsers.add_parser('region')
    parser_composite  = subparsers.add_parser('composite')
    parser_wgmeth     = subparsers.add_parser('wgmeth')

    parser_segmeth.set_defaults(func=segmeth)
    parser_segplot.set_defaults(func=segplot)
    parser_locus.set_defaults(func=locus)
    parser_haplocus.set_defaults(func=haplocus)
    parser_region.set_defaults(func=region)
    parser_composite.set_defaults(func=composite)
    parser_wgmeth.set_defaults(func=wgmeth)
    parser_nanopolish.set_defaults(func=db_nanopolish)
    parser_megalodon.set_defaults(func=db_megalodon)
    parser_guppy.set_defaults(func=db_guppy)

    # options for methylation segments
    parser_segmeth.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_segmeth.add_argument('-i', '--intervals', required=True, help='.bed file')
    parser_segmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_segmeth.add_argument('--excl_ambig', action='store_true', default=False)

    # options for methylation strip / violin plots
    parser_segplot.add_argument('-s', '--segmeth', required=True, help='output from segmeth.py')
    parser_segplot.add_argument('-m', '--samples', default=None, help='samples, comma delimited')
    parser_segplot.add_argument('-c', '--categories', default=None, help='categories, comma delimited, need to match seg_name column from input')
    parser_segplot.add_argument('-v', '--violin', default=False, action='store_true')
    parser_segplot.add_argument('-n', '--mincalls', default=10, help='minimum number of calls to include site (methylated + unmethylated) (default=10)')
    parser_segplot.add_argument('--width', default=12, help='figure width (default = 12)')
    parser_segplot.add_argument('--height', default=6, help='figure height (default = 6)')
    parser_segplot.add_argument('--pointsize', default=1, help='point size for scatterplot (default = 1)')
    parser_segplot.add_argument('--ymin', default=-0.15, help='ymin (default = -0.15)')
    parser_segplot.add_argument('--ymax', default=1.15, help='ymax (default = 1.15)')
    parser_segplot.add_argument('--tiltlabel', default=False, action='store_true')
    parser_segplot.add_argument('--svg', default=False, action='store_true')

    # options for locus-specific plots
    parser_locus.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_locus.add_argument('-i', '--interval', required=True, help='chr:start-end')
    parser_locus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_locus.add_argument('-l', '--highlight', default=None, help='format: start-end')
    parser_locus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_locus.add_argument('-s', '--slidingwindowsize', default=20, help='size of sliding window for meth frac (default=20)')
    parser_locus.add_argument('-t', '--slidingwindowstep', default=2, help='step size for meth frac (default=2)')
    parser_locus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-seperated integers. Default: 1,5,1,3,3')
    parser_locus.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default=8)')
    parser_locus.add_argument('--maskcutoff', default=20, help='windowed read count masking cutoff (default=20)')
    parser_locus.add_argument('--cpgspace', default=10, help='spacing between links in top panel (default=10)')
    parser_locus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_locus.add_argument('--methcall_ymax', default=None)
    parser_locus.add_argument('--methcall_xmin', default=None)
    parser_locus.add_argument('--methcall_xmax', default=None)
    parser_locus.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_locus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_locus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_locus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_locus.add_argument('--svg', action='store_true')

    # options for haplotype locus plots
    parser_haplocus.add_argument('-b', '--bam', required=True, help='bam used for methylation calling')
    parser_haplocus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_haplocus.add_argument('-d', '--methdata', required=True, help='methylation database (.db file)')
    parser_haplocus.add_argument('-i', '--interval', required=True, help='chr:start-end')
    parser_haplocus.add_argument('-l', '--highlight', default=None, help='format: start-end')
    parser_haplocus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_haplocus.add_argument('-s', '--slidingwindowsize', default=20, help='size of sliding window for smoothed plot (default=20)')
    parser_haplocus.add_argument('-t', '--slidingwindowstep', default=2, help='step size of sliding window for smoothed plot (default=2)')
    parser_haplocus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-seperated integers. Default: 1,5,1,3,3')
    parser_haplocus.add_argument('--methcall_ymax', default=None)
    parser_haplocus.add_argument('--cpgspace', default=10, help='spacing between links in top panel (default=10)')
    parser_haplocus.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default=8)')
    # type=int: the haplotype plot hands this value to mask_methfrac() without
    # casting, so a user-supplied value must not stay a string
    parser_haplocus.add_argument('--maskcutoff', default=20, type=int, help='windowed read count masking cutoff (default=20)')
    parser_haplocus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_haplocus.add_argument('--phasepalette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_haplocus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_haplocus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_haplocus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_haplocus.add_argument('--ignore_tags', action='store_true', default=False)
    parser_haplocus.add_argument('--tag_untagged', action='store_true', default=False)
    parser_haplocus.add_argument('--skip_callplot', action='store_true', default=False)
    parser_haplocus.add_argument('--skip_wiggle', action='store_true', default=False)
    parser_haplocus.add_argument('--svg', action='store_true', default=False)

    # options for region plots
    parser_region.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_region.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_region.add_argument('-w', '--windows', default=1000, help='set window count, default=1000')
    parser_region.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_region.add_argument('-m', '--mods', default=None, help='mods to consider (comma-delimited, default = all available)')
    parser_region.add_argument('--svg', action='store_true', default=False)

    # options for composite plots
    parser_composite.add_argument('-s', '--segdata', required=True, help='segmeth output')
    parser_composite.add_argument('-b', '--bam', required=True)
    parser_composite.add_argument('-m', '--meth', required=True, help='methylation database')
    parser_composite.add_argument('-f', '--fams', required=True, help='families, comma delimited, corresponds to values in seg_name column in segmeth output')
    parser_composite.add_argument('-r', '--ref', required=True, help='ref genome fasta')
    parser_composite.add_argument('-t', '--teref', required=True, help='TE ref fasta')
    parser_composite.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_composite.add_argument('-c', '--colour', default='#ff4500', help='colour (default: #ff4500)')
    parser_composite.add_argument('-a', '--alpha', default=0.3, help='alpha (default: 0.3)')
    parser_composite.add_argument('-w', '--linewidth', default=1, help='line width (default: 1)')
    parser_composite.add_argument('-l', '--lenfrac', default=0.95, help='fraction of TE length that must align (default 0.95)')
    parser_composite.add_argument('--blocks', default=None, help='blocks to highlight (txt file with start, end, name, hex colour)')
    parser_composite.add_argument('--start', default=None, help='start plotting at this base (default None)')
    parser_composite.add_argument('--end', default=None, help='end plotting at this base (default None)')
    parser_composite.add_argument('--sample', default=None, help='specify sample name (default = infer from .bam)')
    parser_composite.add_argument('--mincalls', default=100, help='minimum call count to include elt (default = 100)')
    parser_composite.add_argument('--maxelts', default=300, help='maximum elements, if > max random.sample() (default = 300)')
    parser_composite.add_argument('--outelts', default=100, help='maximum output elements, if > max random.sample() (default = 100)')
    parser_composite.add_argument('--slidingwindowsize', default=10, help='size of sliding window for meth frac (default 10)')
    parser_composite.add_argument('--slidingwindowstep', default=1, help='step size for meth frac (default 1)')
    parser_composite.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default 8)')
    parser_composite.add_argument('--globalign', action='store_true', default=False, help='experimental')
    parser_composite.add_argument('--excl_ambig', action='store_true', default=False)
    parser_composite.add_argument('--svg', action='store_true', default=False)

    # options for whole genome output
    parser_wgmeth.add_argument('-b', '--bam', required=True, help='bam used for methylation calling')
    parser_wgmeth.add_argument('-d', '--methdata', required=True, help='methylation database')
    parser_wgmeth.add_argument('-s', '--binsize', default=1000000, help='bin size for parallelisation, default = 1000000')
    parser_wgmeth.add_argument('-f', '--fai', required=True, help='fasta index (.fai)')
    parser_wgmeth.add_argument('-m', '--mod', default=None, help='output for specific mod (names vary, see output for hints)')
    parser_wgmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_wgmeth.add_argument('--dss', default=False, action='store_true', help='output in DSS format (default = bedMethyl)')
    parser_wgmeth.add_argument('--phased', action='store_true', default=False, help='write one output file per phase (phase_0, phase_1)')

    # options for megalodon db
    parser_megalodon.add_argument('-m', '--methdata', required=True, help='whole genome megalodon methylation output, can be comma-delimited')
    parser_megalodon.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_megalodon.add_argument('-p', '--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_megalodon.add_argument('-a', '--append', default=False, action='store_true', help='append to database')

    # options for guppy db
    parser_guppy.add_argument('-s', '--samplename', required=True, help='name for sample')
    parser_guppy.add_argument('-f', '--fast5', required=True, help='fast5 with called bases')
    parser_guppy.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_guppy.add_argument('-m', '--motif', required=True, help='motif e.g. G[A]TC or [C]G')
    parser_guppy.add_argument('-n', '--modname', required=True, help='mod name in guppy fast5 modified base alphabet (5mC, 6mA, etc)')
    parser_guppy.add_argument('-b', '--bam', required=True, help='.bam file containing alignments of reads from fast5')
    parser_guppy.add_argument('-r', '--ref', required=True, help='reference genome fasta (samtools faidx indexed)')
    parser_guppy.add_argument('--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_guppy.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_guppy.add_argument('--include_unmatched', action='store_true', default=False, help='include sites where read base does not match genome base')

    # options for nanopolish db
    parser_nanopolish.add_argument('-m', '--methdata', required=True, help='whole genome nanopolish methylation output, can be comma-delimited')
    parser_nanopolish.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_nanopolish.add_argument('-t', '--thresh', default=2.5, help='llr threshold (default = 2.5; if using --scalegroup the suggested setting is 2.0)')
    parser_nanopolish.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_nanopolish.add_argument('-s', '--scalegroup', default=False, action='store_true', help='scale threshold by number of CpGs in a group')
    parser_nanopolish.add_argument('-n', '--modname', default='CpG', help='modification type (tag if combining multiple mods, default = "CpG")')

    args = parser.parse_args()
    main(args)
