#!/home/taewing/miniconda3/bin/python

import os
import sys
import argparse
import gzip
import csv
import logging
import multiprocessing as mp
import matplotlib
import random
import itertools
import sqlite3

# Select the non-interactive Agg backend so figures can be rendered without an
# X display (e.g. on a headless server). This is done before pyplot is
# imported below.
matplotlib.use('Agg')

# Illustrator compatibility: with 'svg.fonttype': 'none' text is written into
# SVG output as text rather than converted to paths, so labels stay editable;
# usetex is disabled so matplotlib renders the text itself.
new_rc_params = {'text.usetex': False, "svg.fonttype": 'none'}
matplotlib.rcParams.update(new_rc_params)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch
import seaborn as sns

import numpy as np
import pandas as pd
import pysam
import scipy.stats as ss

import skbio.alignment as skalign
import skbio.sequence as skseq

from uuid import uuid4
from collections import defaultdict as dd
from collections import Counter
from itertools import product
from operator import itemgetter
from copy import deepcopy

from ont_fast5_api.fast5_interface import get_fast5_file
from bx.intervals.intersection import Intersecter, Interval

# Root logging configuration: each record is printed as "<timestamp> <message>".
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
# Module logger used throughout this file; INFO and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Read:
    ''' Modified-base calls made on a single read.

    Every per-site dict is keyed by the site position passed as ``cpg_loc``:
    ``llrs`` holds the raw statistic, ``meth_calls`` the call (-1, 0 or 1) and
    ``mod_names`` the modification name.  ``call_count`` counts how many of the
    added calls were non-zero.  ``ypos``, ``starts`` and ``ends`` are layout
    fields filled in by the locus per-read plot.
    '''

    def __init__(self, read_name, cpg_loc, stat, methcall, modname, phase=None):
        self.read_name = read_name
        self.phase = phase

        # per-site data, all keyed by site position
        self.llrs = {}
        self.meth_calls = {}
        self.mod_names = {}

        self.call_count = 0

        # layout state for the locus per-read plot
        self.ypos = None
        self.starts = []
        self.ends = []

        self.add_mod(cpg_loc, stat, methcall, modname)

    def add_mod(self, cpg_loc, stat, methcall, modname):
        ''' Record one call at ``cpg_loc``; ``methcall`` must be -1, 0 or 1. '''
        assert methcall in (-1,0,1)

        self.llrs[cpg_loc] = stat
        self.meth_calls[cpg_loc] = methcall
        self.mod_names[cpg_loc] = modname

        # bool -> int: only non-zero calls are counted
        self.call_count += int(methcall != 0)

    def overlap(self, other):
        ''' True if the [min(starts), max(ends)] spans of both reads overlap by a positive amount. '''
        right = min(max(self.ends), max(other.ends))
        left = max(min(self.starts), min(other.starts))
        return right - left > 0

    def mod_count(self):
        ''' Number of sites called modified (call == 1). '''
        return sum(1 for call in self.meth_calls.values() if call == 1)

    def site_count(self):
        ''' Number of distinct sites with a call. '''
        return len(self.meth_calls)

    def mod_frac(self):
        ''' Fraction of sites called modified; 0 when the read has no sites. '''
        n_sites = self.site_count()
        return self.mod_count()/n_sites if n_sites else 0


class Gene:
    ''' Gene model assembled from GTF records for locus plots.

    Tracks the overall transcript span (tx_start/tx_end), the overall CDS span
    (cds_start/cds_end) and a list of [start, end] exon blocks kept sorted by
    start.  All blocks are stored with start <= end.
    '''

    def __init__(self, ensg, name):
        self.ensg = ensg
        self.name = name
        self.tx_start = None
        self.tx_end = None
        self.cds_start = None
        self.cds_end = None
        self.exons = []

    def add_exon(self, block):
        ''' Add an exon block [start, end] and keep the exon list sorted by start.

        Reversed blocks are normalised to ascending order (matching add_tx and
        add_cds, which treat block[0] as the start).  The caller's list is not
        modified.
        '''
        assert len(block) == 2

        # The previous test here was inverted and stored every exon as
        # [end, start], which also prevented merge_exons from merging anything.
        start, end = min(block), max(block)

        self.exons.append([start, end])
        self.exons.sort(key=itemgetter(0))

    def add_tx(self, block):
        ''' Widen the transcript span to include block [start, end] (start < end required). '''
        assert len(block) == 2
        assert block[0] < block[1]
        if self.tx_start is None or self.tx_start > block[0]:
            self.tx_start = block[0]

        if self.tx_end is None or self.tx_end < block[1]:
            self.tx_end = block[1]

    def add_cds(self, block):
        ''' Widen the CDS span to include block [start, end]; reversed blocks are logged and ignored. '''
        assert len(block) == 2
        if block[0] > block[1]:
            logger.warning('CDS block start > end in gene %s' % self.ensg)
            return None

        if self.cds_start is None or self.cds_start > block[0]:
            self.cds_start = block[0]

        if self.cds_end is None or self.cds_end < block[1]:
            self.cds_end = block[1]

    def has_tx(self):
        ''' True once at least one transcript block has been added. '''
        return None not in (self.tx_start, self.tx_end)

    def has_cds(self):
        ''' True once at least one valid CDS block has been added. '''
        return None not in (self.cds_start, self.cds_end)

    def merge_exons(self):
        ''' Merge exon blocks that overlap by a positive amount.

        Blocks that merely touch (end == next start) stay separate.  Relies on
        self.exons being sorted by start, which add_exon maintains.
        '''
        if not self.exons:
            return

        merged = [list(self.exons[0])]

        for start, end in self.exons[1:]:
            last_start, last_end = merged[-1]

            if min(end, last_end) - max(start, last_start) > 0:  # overlap
                # extend the running block rather than replacing it, so a
                # chain of overlapping exons collapses into one block
                merged[-1] = [min(start, last_start), max(end, last_end)]
            else:
                merged.append([start, end])

        self.exons = merged


# Complement of a single upper-case base (N maps to itself); used to compare
# the read base of a '-' strand alignment against the reference base.
RC = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}

class Motif:
    ''' Sequence motif with one bracketed "key" base, e.g. "[C]G" or "G[A]TC".

    ``left_bases`` / ``right_bases`` are the bases before / after the brackets,
    ``key_base`` is the single bracketed base and ``full_seq`` is the motif
    without brackets.
    '''

    def __init__(self, motif):

        self.motif = motif

        assert len(motif.split('[')) == 2, 'bad motif syntax: %s' % motif
        assert len(motif.split(']')) == 2, 'bad motif syntax: %s' % motif

        # was assigned to a misspelled attribute (left_base), which left
        # left_bases empty and dropped the left context from full_seq
        self.left_bases  = motif.split('[')[0]
        self.right_bases = motif.split(']')[1]
        self.key_base    = motif.split('[')[1].split(']')[0]
        self.full_seq    = self.left_bases + self.key_base + self.right_bases

        # exactly one canonical base; a substring test would accept '' or 'CG'
        assert self.key_base in ('A', 'T', 'C', 'G'), 'bad motif syntax: %s' % motif

    def match_seq(self, seq):
        ''' Return 0-based coords of the key base for every motif match in seq.

        Matching is case-insensitive on ``seq``.  The result is a
        defaultdict(list) mapping key-base position -> [motif string].
        '''
        sites = dd(list)

        seq = seq.upper()
        k = len(self.full_seq)
        offset = len(self.left_bases)

        # + 1 so a match ending on the last base of seq is found
        for i in range(len(seq) - k + 1):
            if seq[i:i+k] == self.full_seq:
                sites[i + offset].append(self.motif)

        return sites
        

class Mod:
    ''' One motif site on a basecalled read, plus its genome placement once known.

    ``mods`` holds the modification values reported for this read position
    (presumably one per entry of the output alphabet — confirm against the
    fast5 ModBaseProbs layout).  chrom, genome_pos, genome_base and
    aln_strand stay None until assigned from the read's alignment.
    '''

    def __init__(self, read_id, read_pos, read_base, sites, mods):
        self.read_id, self.read_pos, self.read_base = read_id, read_pos, read_base
        self.sites = sites
        self.mods = mods

        # genome placement, filled in after the alignment lookup
        self.chrom = self.genome_pos = self.genome_base = self.aln_strand = None

    def __str__(self):
        ''' Tab-separated record: read id, chrom, genome pos, genome base,
        strand, read pos, read base, comma-joined motifs, then each mod value. '''
        cols = [
            self.read_id,
            self.chrom,
            str(self.genome_pos),
            self.genome_base,
            self.aln_strand,
            str(self.read_pos),
            self.read_base,
            ','.join(self.sites),
            '\t'.join(str(value) for value in self.mods),
        ]
        return '\t'.join(cols)


def base_field(mod_metadata):
    ''' Expand a guppy modified-base alphabet into one name per alphabet letter.

    Canonical letters (A, C, T, G) are kept as-is.  Every other letter is
    replaced, in order, by the next whitespace-separated entry of
    ``modified_base_long_names``.  The alphabet is read from
    ``output_alphabet`` if present, otherwise ``modified_base_alphabet``.

    Raises KeyError if the long names or both alphabet fields are missing.
    '''
    long_names = mod_metadata['modified_base_long_names'].split()

    for alpha_field_name in ('output_alphabet', 'modified_base_alphabet'):
        if alpha_field_name in mod_metadata:
            break
    else:
        # previously surfaced as an opaque KeyError(None)
        raise KeyError('mod metadata has neither "output_alphabet" nor "modified_base_alphabet"')

    bases = []
    next_long = 0

    for base in mod_metadata[alpha_field_name]:
        if base in 'ACTG':
            bases.append(base)
        else:
            # non-canonical letters consume the long names in order
            bases.append(long_names[next_long])
            next_long += 1

    return bases


def pandas_df(out, mod_base, can_base):
    ''' Parse guppy mod output lines into a DataFrame and add log-probabilities.

    ``out`` is a sequence of lines: the first is the header (whitespace-
    separated column names), the rest are tab-separated records.  Parsed
    columns are kept as strings.  ``mod_base`` / ``can_base`` name the
    columns holding the modified / canonical values, which are treated as
    0-255 probability bins — hence (p+1)/256; verify against the guppy
    output format.

    Adds float columns can_log_prob, mod_log_prob and
    llr = mod_log_prob - can_log_prob.  A header-only input yields an empty
    frame with those columns instead of raising.
    '''
    lines = iter(out)
    header = next(lines, '').strip().split()
    rows = [line.strip().split('\t') for line in lines]

    # build the frame directly instead of round-tripping through nested dicts
    data = pd.DataFrame(rows, columns=header)

    data['can_log_prob'] = np.log((data[can_base].astype(float)+1)/256)
    data['mod_log_prob'] = np.log((data[mod_base].astype(float)+1)/256)
    data['llr'] = data['mod_log_prob'] - data['can_log_prob']

    return data


def _guppy_fast5_mods(fast5_path, fast5, motif, modname):
    ''' Collect motif-site Mod records from every read in one guppy fast5.

    Returns (bases, modbases): ``bases`` is the alphabet from base_field()
    (empty if the file holds no reads) and ``modbases`` maps read_id -> list
    of Mod.  Returns None, after logging an error, when basecall metadata is
    missing or ``modname`` is not in the alphabet.
    '''
    bases = []
    modbases = dd(list)

    with get_fast5_file(fast5_path, mode="r") as f5:
        for read_id in f5.get_read_ids():
            read = f5.get_read(read_id)
            latest_basecall = read.get_latest_analysis('Basecall_1D')

            mod_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/ModBaseProbs')
            called_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/Fastq')

            mod_metadata = read.get_analysis_attributes('{}/BaseCalled_template/ModBaseProbs'.format(latest_basecall))
            called_metadata = read.get_analysis_attributes('{}/BaseCalled_template/Fastq'.format(latest_basecall))

            if None in (mod_metadata, called_metadata):
                logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
                return None

            if not bases:
                # alphabet is taken from the first read of the file
                bases = base_field(mod_metadata)

                mod_names = [b for b in bases if b not in 'ATCG']

                if modname not in mod_names:
                    logger.error('mod "%s" not found in metadata. Available mod names: %s' % (modname, ','.join(mod_names)))
                    return None

            # line 1 (0-based) of the Fastq dataset is the called sequence
            seq = called_base_table.split('\n')[1]

            assert len(seq) == len(mod_base_table)

            for i, site_motifs in motif.match_seq(seq).items():
                modbases[read_id].append(Mod(read_id, i, seq[i], site_motifs, mod_base_table[i]))

    return bases, modbases


def _guppy_alignment_lines(read, read_mods, ref, include_unmatched):
    ''' Place one read's Mod records on the genome via alignment ``read``.

    Sets chrom/genome_pos/genome_base/aln_strand on each Mod whose read
    position is aligned (not in an indel) and returns its output line.  Unless
    ``include_unmatched`` is set, only sites whose read base matches the
    reference base (complemented for '-' alignments) are returned.
    '''
    lines = []

    # query position -> reference position, aligned (non-indel) bases only
    pos_lookup = {q: r for q, r in read.get_aligned_pairs() if q is not None and r is not None}

    for modbase in read_mods:
        modbase.aln_strand = '+'
        aligned_read_pos = modbase.read_pos

        if read.is_reverse:
            modbase.aln_strand = '-'
            # SAM format stores read on + strand
            aligned_read_pos = (len(read.query_sequence) - modbase.read_pos)-1

        if aligned_read_pos not in pos_lookup:
            continue

        modbase.chrom = read.reference_name
        modbase.genome_pos = pos_lookup[aligned_read_pos]
        modbase.genome_base = ref.fetch(modbase.chrom, modbase.genome_pos, modbase.genome_pos+1).upper()

        if include_unmatched:
            keep = True
        elif modbase.aln_strand == '+':
            keep = modbase.genome_base == modbase.read_base
        else:
            # .get(): reference bases outside A/C/G/T/N (e.g. IUPAC codes)
            # simply do not match instead of raising KeyError
            keep = RC.get(modbase.genome_base) == modbase.read_base

        if keep:
            lines.append(str(modbase) + '\n')

    return lines


def guppy_f5_fetch(fast5, bam_db, args):
    ''' Extract guppy modified-base calls at motif sites from one fast5 and map them to the genome.

    ``fast5`` is a file name inside directory ``args.fast5``; ``bam_db`` is a
    sqlite database with a ``bam`` table of (readname, sam) rows, parsed
    with the header of ``args.bam``.  Sites are found with ``args.motif`` and
    reference bases come from ``args.ref``.

    Writes <fast5 stem>.gupmod.tsv next to the fast5 and returns its path, or
    returns None if the fast5 lacks basecall data, lacks ``args.modname`` or
    contains no reads.
    '''
    m = Motif(args.motif)

    parsed = _guppy_fast5_mods(args.fast5 + '/' + fast5, fast5, m, args.modname)

    if parsed is None:
        return None

    bases, modbases = parsed

    logger.info('processed %d reads from %s' % (len(modbases), fast5))

    if not bases:
        # no reads in the file: the header could not be built
        logger.warning('no reads found in %s' % fast5)
        return None

    logger.info('parsing alignments from %s' % bam_db)

    header = "readname\tchrom\tpos\tgenome_base\tstrand\tread_pos\tread_base\tmotif\t" + '\t'.join(bases) + "\n"
    out = [header]

    bam = pysam.AlignmentFile(args.bam)
    ref = pysam.Fastafile(args.ref)
    conn = sqlite3.connect(bam_db)

    try:
        c = conn.cursor()

        for readname, read_mods in modbases.items():
            # parameterised: read names are never spliced into the SQL text
            for (sam,) in c.execute("SELECT sam FROM bam WHERE readname=?", (readname,)):
                read = pysam.AlignedSegment.fromstring(sam, bam.header)
                out.extend(_guppy_alignment_lines(read, read_mods, ref, args.include_unmatched))
    finally:
        conn.close()
        ref.close()
        bam.close()

    data = pandas_df(out, args.modname, m.key_base)
    out_fn = args.fast5 + '/' + '.'.join(fast5.split('.')[:-1]) + '.gupmod.tsv'
    logger.info('writing output from %s to %s' % (fast5, out_fn))
    data.to_csv(out_fn, sep='\t', index=False)

    return out_fn


def exclude_ambiguous_reads(fn, chrom, start, end, min_mapq=10):
    ''' Return names of reads anchored outside chrom:start-end.

    Despite the name, this returns the reads to KEEP: alignments in BAM ``fn``
    overlapping the interval whose first aligned reference position is before
    ``start`` or whose last is after ``end``, with mapping quality >= min_mapq.
    Alignments with no aligned reference positions are skipped.  A name may
    appear more than once if it has several qualifying alignments.
    '''
    reads = []

    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            if not p:
                # nothing aligned: p[0] / p[-1] would raise IndexError
                continue

            if (p[0] < start or p[-1] > end) and read.mapping_quality >= min_mapq:
                reads.append(read.query_name)

    return reads


def get_ambiguous_reads(fn, chrom, start, end, min_mapq=10, w=50):
    ''' Return names of reads that cannot be confidently placed over chrom:start-end.

    An alignment in BAM ``fn`` overlapping the interval counts as ambiguous
    if its mapping quality is below ``min_mapq``, if it has no aligned
    reference positions, or if it is not anchored at least ``w`` bp beyond
    the interval on either side (first aligned position > start-w and last
    aligned position < end+w).
    '''
    reads = []

    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            # an alignment with no aligned positions cannot anchor anything;
            # checking it first also avoids IndexError on p[0]
            if read.mapping_quality < min_mapq or not p or (p[0] > start-w and p[-1] < end+w):
                reads.append(read.query_name)

    return reads


def get_reads(fn, chrom, start, end, min_mapq=10):
    ''' Names of alignments in BAM ``fn`` overlapping chrom:start-end with
    mapping quality >= min_mapq (a name can repeat, e.g. for supplementary
    alignments). '''
    bam = pysam.AlignmentFile(fn)
    return [aln.query_name for aln in bam.fetch(chrom, start, end) if aln.mapping_quality >= min_mapq]


def get_phased_reads(fn, chrom, start, end, min_mapq=10, tag_untagged=False, ignore_tags=False):
    ''' Map read name -> phase for alignments over chrom:start-end in BAM ``fn``.

    Only alignments with mapping quality >= min_mapq are considered.  A read
    carrying both PS and HP tags gets the phase string 'PS:HP' (tags are not
    read when ignore_tags is set).  Otherwise the phase is 'unphased' when
    tag_untagged or ignore_tags is set, and None if not.  If a name has
    several alignments, the last one seen wins.
    '''
    phases = {}
    bam = pysam.AlignmentFile(fn)

    for aln in bam.fetch(chrom, start, end):
        if aln.mapping_quality < min_mapq:
            continue

        tags = {} if ignore_tags else dict(aln.get_tags())
        ps, hp = tags.get('PS'), tags.get('HP')

        if ps is not None and hp is not None:
            phases[aln.query_name] = ':'.join((str(ps), str(hp)))
        elif tag_untagged or ignore_tags:
            phases[aln.query_name] = 'unphased'
        else:
            phases[aln.query_name] = None

    return phases


def single_seq_fa(fn):
    ''' Return the sequence of a single-record FASTA file as one string.

    Sequence lines are whitespace-stripped and concatenated with case
    preserved.  An AssertionError is raised if a header line follows sequence
    data, i.e. the file holds more than one record.
    '''
    pieces = []

    with open(fn, 'r') as fa:
        for line in fa:
            if line.startswith('>'):
                assert not any(pieces), 'input fa must have only one entry'
                continue

            pieces.append(line.strip())

    return ''.join(pieces)


def rc(dna):
    ''' Reverse complement ``dna``.

    Case is preserved and IUPAC ambiguity codes are complemented; characters
    without a listed complement (e.g. N, S, W) pass through unchanged.
    '''
    table = str.maketrans('acgtrymkbdhvACGTRYMKBDHV', 'tgcayrkmvhdbTGCAYRKMVHDB')
    return dna[::-1].translate(table)


def get_modnames(meth_db):
    ''' Return the distinct modification names in the ``modnames`` table of
    sqlite database ``meth_db`` (in the order sqlite returns them). '''
    conn = sqlite3.connect(meth_db)

    try:
        return [row[0] for row in conn.execute("SELECT DISTINCT mod FROM modnames")]
    finally:
        # the connection used to be left open
        conn.close()


def get_cutoffs(meth_db, mod):
    ''' Return the (upper, lower) cutoffs stored for ``mod`` in the ``cutoffs``
    table of ``meth_db``; the first row if several, None if there is none. '''
    conn = sqlite3.connect(meth_db)

    try:
        # parameterised: a quote in the mod name used to break the query
        return conn.execute("SELECT upper,lower FROM cutoffs WHERE modname=?", (mod,)).fetchone()
    finally:
        conn.close()


def densecall_filter(reads, max_density=.7):
    ''' Drop reads whose fraction of modified calls exceeds ``max_density``.

    ``reads`` maps read name -> Read.  Returns a new dict holding only the
    reads whose mod_frac() is <= max_density, and logs how many were removed.
    '''
    kept = {name: read for name, read in reads.items() if read.mod_frac() <= max_density}

    logger.info('filtered %d reads via --max_read_density %f' % (len(reads) - len(kept), max_density))

    return kept


def get_segmeth_calls(args, bam_fn, mod_names, meth_db, chrom, seg_start, seg_end, seg_name, seg_strand):
    ''' Count per-mod calls over one segment.

    Reads come from exclude_ambiguous_reads() when args.excl_ambig is set,
    otherwise get_reads().  Their calls on ``chrom`` within
    [seg_start, seg_end] are loaded from the ``methdata`` table of
    ``meth_db``, optionally filtered by args.max_read_density.

    Returns (seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand,
    read_count)), where seg_result maps each name in ``mod_names`` to a
    defaultdict(int) of call value (-1/0/1) -> count.
    '''
    if getattr(args, 'excl_ambig', False):
        reads = exclude_ambiguous_reads(bam_fn, chrom, seg_start, seg_end)
    else:
        reads = get_reads(bam_fn, chrom, seg_start, seg_end)

    reads = list(set(reads))

    seg_reads = {}

    conn = sqlite3.connect(meth_db)
    try:
        c = conn.cursor()

        for index in reads:
            # parameterised: read names are never spliced into the SQL text
            rows = c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,))

            for cg_chrom, cg_start, stat, methstate, modname in rows:
                if chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
    finally:
        conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    # a single pass over the calls instead of one full scan per mod name
    seg_result = {modname: dd(int) for modname in mod_names}

    for read in seg_reads.values():
        for loc, call in read.meth_calls.items():
            counts = seg_result.get(read.mod_names[loc])
            if counts is not None:
                counts[call] += 1

    read_count = len(seg_reads)

    return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand, read_count)


def get_meth_locus(args, bam, meth_db, mod, phase=None):
    ''' Collect per-read calls for ``mod`` over args.interval (used for locus plots).

    args.interval is 'chrom:start-end' (commas allowed in the numbers).
    Candidate reads come from exclude_ambiguous_reads() when args.excl_ambig
    is set, otherwise get_reads().  With ``phase``, only reads whose
    get_phased_reads() phase equals it are kept.  With args.unambig_highlight
    and args.highlight, reads flagged by get_ambiguous_reads() over the span
    of the highlight coordinates are dropped.  Ambiguous calls (methstate 0),
    other mods, other chromosomes and sites outside the interval are skipped.

    Returns {read name: Read} with site positions relative to the interval
    start, after the optional args.max_read_density filter.
    '''
    if phase:
        logger.info('fetching reads from %s with for mod %s on phase %s' % (bam, mod, phase))
    else:
        logger.info('fetching reads from %s with for mod %s' % (bam, mod))

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))

    # get list of relevant reads (excludes reads not anchored outside interval)
    if getattr(args, 'excl_ambig', False):
        reads = exclude_ambiguous_reads(bam, chrom, elt_start, elt_end)
    else:
        reads = get_reads(bam, chrom, elt_start, elt_end)

    reads = list(set(reads))

    if phase:
        phased_reads_dict = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=(phase=='unphased'))
        reads = [r for r in reads if phased_reads_dict.get(r) == phase]

    if getattr(args, 'unambig_highlight', None) and getattr(args, 'highlight', None):
        h_coords = []
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]

            h_coords += map(int, h.split('-'))

        h_coords.sort()

        h_start, h_end = h_coords[0], h_coords[-1]

        # set: membership is tested once per candidate read
        excl_reads = set(get_ambiguous_reads(bam, chrom, h_start, h_end))
        reads = [read for read in reads if read not in excl_reads]

    seg_reads = {}

    conn = sqlite3.connect(meth_db)
    try:
        c = conn.cursor()

        for index in reads:
            # parameterised: read names are never spliced into the SQL text
            rows = c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,))

            for cg_chrom, cg_start, stat, methstate, modname in rows:
                if methstate == 0:
                    continue

                if modname != mod:
                    continue

                if chrom != cg_chrom:
                    continue

                if cg_start < elt_start or cg_start > elt_end:
                    continue

                cg_seg_start = cg_start - elt_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
    finally:
        conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    return seg_reads


def _composite_elt_to_ref(tab_msa, pos_ref, pos_elt):
    ''' Map element coordinates to TE-reference coordinates through a pairwise alignment.

    ``tab_msa`` is the two-row scikit-bio TabularMSA returned by the aligner
    (row 0 = TE reference, row 1 = element).  ``pos_ref`` / ``pos_elt`` are
    the alignment start offsets in each sequence.  Element positions in
    gap-free columns map to their reference position; element positions
    aligned to a reference gap map to 'na'.
    '''
    elt_to_ref = {}

    for column in tab_msa.iter_positions():
        # str() yields the column's characters.  Elements of list(column) are
        # skbio Sequence objects, which never compare equal to the str '-',
        # so gaps would go undetected.
        col = str(column)
        b_ref, b_elt = col[0], col[1]

        if b_ref != '-' and b_elt != '-':
            elt_to_ref[pos_elt] = pos_ref
            pos_ref += 1
            pos_elt += 1

        if b_elt == '-':
            pos_ref += 1

        if b_ref == '-':
            elt_to_ref[pos_elt] = 'na'
            pos_elt += 1

    return elt_to_ref


def get_meth_profile_composite(args, seg_chrom, seg_start, seg_end, seg_name, seg_strand, use_mod):
    ''' Smoothed methylation profile of one segment projected onto TE reference coordinates.

    Loads ``use_mod`` calls for reads over seg_chrom:seg_start-seg_end.  It
    then computes a sliding-window methylation fraction over the dense rank
    of each site and smooths it.  The segment sequence (reverse-complemented
    for '-' strand) is aligned to the single-sequence TE reference in
    args.teref, and each window position is mapped to its reference
    coordinate.

    Returns (ref_positions, meth_values).  Both are empty lists when the
    segment has no calls, too few windows, no alignment, or an alignment
    shorter than args.lenfrac of either sequence.
    '''
    logger.info('profiling %s %s:%d-%d:%s:%s' % (seg_name, seg_chrom, seg_start, seg_end, seg_strand, use_mod))

    te_ref_seq = single_seq_fa(args.teref)

    if args.excl_ambig:
        reads = exclude_ambiguous_reads(args.bam, seg_chrom, seg_start, seg_end)
    else:
        reads = get_reads(args.bam, seg_chrom, seg_start, seg_end)

    reads = list(set(reads))

    seg_reads = {}

    conn = sqlite3.connect(args.meth)
    try:
        c = conn.cursor()

        for index in reads:
            # parameterised: read names are never spliced into the SQL text
            rows = c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,))

            for cg_chrom, cg_start, stat, methstate, modname in rows:
                if seg_chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                if modname != use_mod:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
    finally:
        conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    if not seg_reads:
        # an empty table has no 'loc' column; the old code crashed with KeyError
        logger.warning('no %s calls on seg: %s:%d-%d' % (use_mod, seg_chrom, seg_start, seg_end))
        return [], []

    sample = '.'.join(args.bam.split('.')[:-1])

    meth_table = pd.DataFrame([
        {'loc': loc, 'llr': read.llrs[loc], 'read': name, 'sample': sample, 'call': read.meth_calls[loc]}
        for name, read in seg_reads.items()
        for loc in read.llrs
    ])

    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['llr'] = pd.to_numeric(meth_table['llr'])

    # windows are measured in sites: replace offsets by their dense rank
    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    cpg_to_coord = dict(zip(meth_table['loc'], meth_table['orig_loc']))

    windowed_methfrac, _ = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

    if len(windowed_methfrac) <= int(args.smoothwindowsize):
        logger.warning('too few sites after windowing: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
        return [], []

    smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

    # Keep each window midpoint paired with its smoothed value while
    # filtering, so a skipped midpoint cannot shift later values onto the
    # wrong coordinate.
    seg_len = seg_end - seg_start
    meth_points = []

    for cpg, meth in zip(windowed_methfrac, smoothed_methfrac):
        if cpg not in cpg_to_coord:
            continue

        if seg_strand == '+':
            meth_points.append((cpg_to_coord[cpg], meth))
        elif seg_strand == '-':
            # flip into the element's orientation (offset of 2 kept as-is)
            meth_points.append((seg_len - cpg_to_coord[cpg] - 2, meth))

    # alignment to ref elt

    ref = pysam.Fastafile(args.ref)
    try:
        elt_seq = ref.fetch(seg_chrom, seg_start, seg_end)
    finally:
        ref.close()

    if seg_strand == '-':
        elt_seq = rc(elt_seq)

    te_ref_seq = te_ref_seq.upper()
    elt_seq = elt_seq.upper()

    s_ref = skseq.DNA(te_ref_seq)
    s_elt = skseq.DNA(elt_seq)

    try:
        if args.globalign:
            aln_res = skalign.global_pairwise_align_nucleotide(s_ref, s_elt)
        else:
            aln_res = skalign.local_pairwise_align_ssw(s_ref, s_elt)
    except IndexError: # scikit-bio throws this if no bases align
        logger.warning('no align on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
        return [], []

    coord_ref, coord_elt = aln_res[2]

    len_ref = coord_ref[1] - coord_ref[0]
    len_elt = coord_elt[1] - coord_elt[0]

    if len_ref / len(te_ref_seq) < float(args.lenfrac):
        logger.warning('ref align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_ref / len(te_ref_seq)))
        return [], []

    if len_elt / len(elt_seq) < float(args.lenfrac):
        logger.warning('elt align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_elt / len(elt_seq)))
        return [], []

    elt_to_ref_coords = _composite_elt_to_ref(aln_res[0], coord_ref[0], coord_elt[0])

    revised_coord_meth_pos = []
    meth_profile = []

    for pos, meth in meth_points:
        # unaligned positions and positions opposite a reference gap are dropped
        revised_pos = elt_to_ref_coords.get(pos, 'na')

        if revised_pos != 'na':
            revised_coord_meth_pos.append(revised_pos)
            meth_profile.append(meth)

    return revised_coord_meth_pos, meth_profile


def get_meth_calls_wg(args, bam_fn, meth_fn, chrom, seg_start, seg_end, phased, mod):
    ''' Per-position methylation counts over one segment, optionally split by haplotype.

    Reads come from get_phased_reads() (untagged reads are marked 'unphased'
    when ``phased``).  Their calls on ``chrom`` within [seg_start, seg_end]
    are loaded from the ``methdata`` table of ``meth_fn``, restricted to
    ``mod`` unless it is None, and optionally filtered by
    args.max_read_density.

    Returns [table_hp1, table_hp2]; each maps genome position ->
    {'chr': chrom, 'N': non-ambiguous call count, 'X': methylated call count}.
    Only positions with N > 0 appear.  When not ``phased``, all reads go into
    the first table and the second is empty.  When ``phased``, only reads on
    haplotype 1 or 2 are counted.
    '''
    reads = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=phased)

    seg_reads = {}

    conn = sqlite3.connect(meth_fn)
    try:
        c = conn.cursor()

        for index in reads:
            # parameterised: read names are never spliced into the SQL text
            rows = c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,))

            for cg_chrom, cg_start, stat, methstate, modname in rows:
                if mod is not None and mod != modname:
                    continue

                if chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
    finally:
        conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    # calls grouped by haplotype, then by segment-relative site
    meth_data = {1: dd(list), 2: dd(list)}

    for read in seg_reads.values():
        if phased:
            # phase strings look like 'PS:HP'; keep haplotype 1 and 2 only
            if read.phase is None or read.phase == 'unphased':
                continue

            hap = read.phase.split(':')[1]

            if hap not in ('1', '2'):
                continue

            phase = int(hap)
        else:
            # pass through as phase 1 if not analysing phases
            phase = 1

        for loc, call in read.meth_calls.items():
            meth_data[phase][loc].append(call)

    meth_table = {1: dd(dict), 2: dd(dict)}

    for phase, site_calls in meth_data.items():
        for loc, calls in site_calls.items():
            N = sum(1 for call in calls if call != 0) # call 0 == ambiguous
            X = sum(1 for call in calls if call == 1) # call 1 == methylated

            if N > 0:
                pos = loc+seg_start
                meth_table[phase][pos]['chr'] = chrom
                meth_table[phase][pos]['N'] = N
                meth_table[phase][pos]['X'] = X

    return [meth_table[1], meth_table[2]]


def slide_window(meth_table, sample, width=20, slide=2):
    ''' Sliding-window methylation fraction for one sample (used for locus and composite plots).

    Windows of ``width`` advance by ``slide`` along meth_table['loc'] until
    the window midpoint reaches max(loc).  Each window counts, for rows of
    ``sample``, calls of 1 (methylated) and -1 (unmethylated) whose loc lies
    strictly inside it; calls of 0 are ignored.

    Returns (meth_frac, meth_n), two dicts keyed by integer window midpoint
    holding meth/(meth+unmeth) and meth+unmeth.  Windows with no calls are
    omitted.
    '''
    lo = min(meth_table['loc'])
    hi = max(meth_table['loc'])

    # row masks that do not depend on the window position
    in_sample = meth_table['sample'] == sample
    is_meth = in_sample & (meth_table['call'] == 1)
    is_unmeth = in_sample & (meth_table['call'] == -1)

    start = int(lo - width/2)
    end = start + width

    frac_by_mid = {}
    n_by_mid = {}

    while int((start+end)/2) < hi:
        start, end = start + slide, end + slide

        inside = (meth_table['loc'] > start) & (meth_table['loc'] < end)
        n_meth = int((is_meth & inside).sum())
        n_unmeth = int((is_unmeth & inside).sum())
        total = n_meth + n_unmeth

        if total:
            mid = int((start+end)/2)
            frac_by_mid[mid] = n_meth/total
            n_by_mid[mid] = total

    return frac_by_mid, n_by_mid


def slide_window_phased(meth_table, phase, width=20, slide=2):
    ''' Sliding-window methylation fraction for one phase.

    Same windowing as slide_window(), but rows are selected by
    meth_table['phase'] == phase instead of by sample.  Returns
    (meth_frac, meth_n) keyed by integer window midpoint; windows with no
    calls are omitted.
    '''
    lo = min(meth_table['loc'])
    hi = max(meth_table['loc'])

    # row masks that do not depend on the window position
    in_phase = meth_table['phase'] == phase
    is_meth = in_phase & (meth_table['call'] == 1)
    is_unmeth = in_phase & (meth_table['call'] == -1)

    start = int(lo - width/2)
    end = start + width

    frac_by_mid = {}
    n_by_mid = {}

    while int((start+end)/2) < hi:
        start, end = start + slide, end + slide

        inside = (meth_table['loc'] > start) & (meth_table['loc'] < end)
        n_meth = int((is_meth & inside).sum())
        n_unmeth = int((is_unmeth & inside).sum())
        total = n_meth + n_unmeth

        if total:
            mid = int((start+end)/2)
            frac_by_mid[mid] = n_meth/total
            n_by_mid[mid] = total

    return frac_by_mid, n_by_mid


def smooth(x, window_len=8, window='hanning'):
    # used for locus plots
    ''' Smooth a 1-D array by convolving with a normalised window.

    Modified from the scipy cookbook:
    https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    ``window_len`` must be even and smaller than x.size.  ``window`` is one of
    'flat' (moving average), 'hanning', 'hamming', 'bartlett' or 'blackman'.
    The signal is mirrored at both ends before convolving, and the output is
    trimmed back to len(x).  A window_len below 3 returns x unchanged.
    '''
    assert window_len % 2 == 0, '--smoothwindowsize must be an even number'
    assert x.ndim == 1
    assert x.size > window_len, 'fewer data points than window_len'

    if window_len < 3:
        return x

    assert window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']

    # mirror the signal at both ends so edge windows have full support
    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]

    if window == 'flat': # moving average
        w = np.ones(window_len, 'd')
    else:
        # look the numpy window function up by name rather than eval()-ing a string
        w = getattr(np, window)(window_len)

    y = np.convolve(w/w.sum(), s, mode='valid')

    # trim back to len(x); the trim is asymmetric because window_len is even
    half = window_len // 2
    return y[(half-1):-half]


def mask_methfrac(data, cutoff=20):
    # used for locus plots
    ''' Return runs of consecutive indices whose value is <= cutoff.

    ``cutoff`` is converted with int().  Each run is returned as a list of
    indices; runs appear in order and values > cutoff separate them.
    '''
    above = np.asarray(data) > int(cutoff)

    segs = []
    run_start = None

    for idx, is_above in enumerate(above):
        if is_above:
            if run_start is not None:
                segs.append(list(range(run_start, idx)))
                run_start = None
        elif run_start is None:
            run_start = idx

    # close a run that reaches the end of the data
    if run_start is not None:
        segs.append(list(range(run_start, len(above))))

    return segs


def build_genes(gtf, chrom, start, end):
    # used for locus plots
    ''' Build Gene objects from GTF records overlapping chrom:start-end.

    ``gtf`` must provide fetch(chrom, start, end) yielding tab-separated
    9-column GTF lines (a tabix-indexed GTF, presumably — confirm at call
    sites).  Records without gene_id or gene_name are skipped.  exon, CDS and
    transcript features update the corresponding Gene via add_exon, add_cds
    and add_tx.

    Returns {gene_id: Gene}.
    '''
    genes = {}

    for line in gtf.fetch(chrom, start, end):

        # distinct names so the function arguments are not shadowed
        _, _, feature, f_start, f_end, _, _, _, attribs = line.split('\t')

        block = [int(f_start), int(f_end)]

        attr_dict = {}

        for attrib in attribs.strip().split(';'):
            # split key from value once, so quoted multi-word values
            # (e.g. gene_name "foo bar") are kept whole
            parts = attrib.strip().split(None, 1)

            if len(parts) < 2:
                # blank or malformed entry (e.g. '; ;'); used to raise ValueError
                continue

            key, val = parts
            attr_dict[key] = val.strip().strip('"')

        if 'gene_id' not in attr_dict:
            continue

        if 'gene_name' not in attr_dict:
            continue

        ensg = attr_dict['gene_id']
        name = attr_dict['gene_name']

        if ensg not in genes:
            genes[ensg] = Gene(ensg, name)

        if feature == 'exon':
            genes[ensg].add_exon(block)

        if feature == 'CDS':
            genes[ensg].add_cds(block)

        if feature == 'transcript':
            genes[ensg].add_tx(block)

    return genes


def sample_db(db, mod, n=1000000):
    ''' Return up to ``n`` randomly chosen ``stat`` values for ``mod`` from the
    ``methdata`` table of sqlite database ``db``, as a flat numpy array. '''
    conn = sqlite3.connect(db)

    try:
        c = conn.cursor()

        logger.info('sample %d values from %s where modname="%s"' % (n, db, mod))

        # Parameterised: the old double-quoted "%s" is read by sqlite as an
        # identifier first (it would match a column of the same name) and
        # broke on quotes.
        rows = c.execute('SELECT stat FROM methdata WHERE modname=? ORDER BY RANDOM() LIMIT ?', (mod, int(n))).fetchall()
    finally:
        conn.close()

    return np.asarray(rows).flatten()


## tools

def scoredist(args):
    '''
    Plot the stat distribution of one modification across one or more databases.

    args.db:   comma-separated methylation .db paths
    args.mod:  modification name; must be listed by get_modnames for some db
    args.n:    number of stat values to sample from each db (see sample_db)
    args.xmin, args.xmax: optional x-axis limits (otherwise matplotlib's)
    args.svg:  write .svg instead of .png

    Writes <db1>_<db2>....scoredist.(png|svg): one KDE per db, plus dashed
    vertical lines at the upper/lower cutoffs returned by get_cutoffs.
    '''
    dbs = args.db.split(',')

    # check the paths before anything opens them: sqlite3.connect silently
    # creates a missing file, which would make a later existence check pass
    for db in dbs:
        assert os.path.exists(db), '%s not found' % db

    avail_mods = set()

    for db in dbs:
        avail_mods.update(get_modnames(db))

    avail_mods = list(avail_mods)

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    sample = {}
    meth_cutoffs = []
    unmeth_cutoffs = []

    for db in dbs:
        upper, lower = get_cutoffs(db, args.mod)
        meth_cutoffs.append(upper)
        unmeth_cutoffs.append(lower)

        sample[db] = sample_db(db, args.mod, n=int(args.n))

    meth_cutoffs = list(set(meth_cutoffs))
    unmeth_cutoffs = list(set(unmeth_cutoffs))

    if len(meth_cutoffs) > 1:
        logger.warning('multiple upper (methylation) cutoffs: ' + ','.join(map(str, meth_cutoffs)))

    if len(unmeth_cutoffs) > 1:
        logger.warning('multiple lower (unmethylation) cutoffs: ' + ','.join(map(str, unmeth_cutoffs)))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    sns.kdeplot(data=sample, ax=ax)

    top = ax.get_ylim()[1]

    ax.vlines(meth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')
    ax.vlines(unmeth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')

    xmin, xmax = ax.get_xlim()

    if args.xmin:
        xmin = float(args.xmin)

    if args.xmax:
        xmax = float(args.xmax)

    ax.set_xlim(xmin, xmax)

    out_fn = '_'.join(dbs) + '.scoredist'

    if args.svg:
        out_fn += '.svg'
    else:
        out_fn += '.png'

    fig.savefig(out_fn)
    logger.info('plot written to %s' % out_fn)
    plt.close(fig)


def adjustcutoffs(args):
    '''
    Re-call methylation states for one modification using new cutoffs.

    args.db:           methylation .db path
    args.mod:          modification name (must be listed by get_modnames)
    args.methylated:   sites with stat > this value get methstate 1
    args.unmethylated: sites with stat < this value get methstate -1

    All other sites of that mod get methstate 0. The -1 update runs last, so it
    wins if the two thresholds overlap. The cutoffs table is updated to match.
    '''
    avail_mods = get_modnames(args.db)

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    meth_cutoff = float(args.methylated)
    unmeth_cutoff = float(args.unmethylated)

    conn = sqlite3.connect(args.db)

    try:
        c = conn.cursor()

        # bound parameters throughout: full-precision thresholds, and mod names
        # are never interpreted as SQL (or as double-quoted identifiers)
        logger.info('%s: reset methylation states to 0 for mod %s' % (args.db, args.mod))
        c.execute('UPDATE methdata SET methstate=0 WHERE modname=?', (args.mod,))

        logger.info('%s: mark sites with stat > %f methylated (1) for mod %s' % (args.db, meth_cutoff, args.mod))
        c.execute('UPDATE methdata SET methstate=1 WHERE stat > ? AND modname=?', (meth_cutoff, args.mod))

        logger.info('%s: mark sites with stat < %f unmethylated (-1) for mod %s' % (args.db, unmeth_cutoff, args.mod))
        c.execute('UPDATE methdata SET methstate=-1 WHERE stat < ? AND modname=?', (unmeth_cutoff, args.mod))

        logger.info('%s: update cutoff table for mod %s' % (args.db, args.mod))
        c.execute('UPDATE cutoffs SET upper=?, lower=? WHERE modname=?', (meth_cutoff, unmeth_cutoff, args.mod))

        conn.commit()
    finally:
        conn.close()

def db_guppy(args):
    '''
    Build <samplename>.guppy.db from guppy base modification calls in fast5 files.

    Steps:
      1. cache primary, non-duplicate alignments from args.bam (comma-separated)
         into <samplename>.bamcache.db, unless that cache file already exists
      2. run guppy_f5_fetch on every .fast5 file in args.fast5 with args.procs
         worker processes; each call returns a tab-delimited file name or None
      3. load those files into the methdata table

    A site is called 1 when mod_log_prob >= args.minprob, -1 when
    can_log_prob >= args.minprob, otherwise 0. The stored stat is
    mod_log_prob - can_log_prob rounded to 2 decimals. Minus-strand positions
    are shifted by -1 to match the nanopolish scheme. All sites are recorded
    under args.modname.

    args.append adds records to an existing database instead of creating one.
    '''
    assert os.path.exists(args.fast5), 'path not found: %s' % args.fast5

    db_fn = args.samplename + '.guppy.db'

    # validate the output db before the expensive bam caching / fast5 processing
    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    bam_db = args.samplename + '.bamcache.db'

    if os.path.exists(bam_db):
        logger.info('using existing bam cache db %s' % bam_db)

    else:
        conn = sqlite3.connect(bam_db)
        c = conn.cursor()

        logger.info('caching %s to %s' % (args.bam, bam_db))
        c.execute('''CREATE TABLE bam (readname text, sam text)''')
        c.execute('''CREATE INDEX read_index ON bam(readname)''')

        commit_interval = 100000

        for bam_fn in args.bam.split(','):
            bam = pysam.AlignmentFile(bam_fn)

            for i, read in enumerate(bam.fetch(), 1):
                if read.is_secondary or read.is_supplementary or read.is_duplicate:
                    continue

                read.query_qualities = None

                # bound parameters: SAM text may contain quote characters
                c.execute('INSERT INTO bam VALUES (?, ?)', (read.query_name, read.to_string()))

                if i % commit_interval == 0:
                    logger.info('commiting %d records to %s...' % (commit_interval, bam_db))
                    conn.commit()

            bam.close()
            logger.info('commiting remaining records to %s...' % bam_db)

        conn.commit()
        conn.close()

    fast5s = [fn for fn in os.listdir(args.fast5) if fn.endswith('.fast5')]

    logger.info('found %d fast5 files in %s' % (len(fast5s), args.fast5))

    logger.info('fetching base modifications...')

    # the pool is terminated when the block exits, after every result is collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = [pool.apply_async(guppy_f5_fetch, [fast5, bam_db, args]) for fast5 in fast5s]
        outfiles = [res.get() for res in results]

    minprob = float(args.minprob)
    mod_base = args.modname
    batch_size = 100000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)
    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
        known_mods = set()

    else:
        logger.info('appending records to %s' % db_fn)
        known_mods = set(row[0] for row in c.execute('SELECT mod FROM modnames'))

    loaded = False

    for fn in outfiles:
        if fn is None:
            continue

        with open(fn) as tsv:
            logger.info('loading %s into %s' % (fn, db_fn))

            ins_data = []

            for row in csv.DictReader(tsv, delimiter='\t'):
                pos      = int(row['pos'])
                strand   = row['strand']
                mod_prob = float(row['mod_log_prob'])
                can_prob = float(row['can_log_prob'])

                assert (mod_prob + can_prob) < 1.01 # allow for some rounding error

                # adjust position of - strand calls to match nanopolish scheme
                if strand == '-':
                    pos -= 1

                methcall = 0

                if mod_prob >= minprob:
                    methcall = 1

                if can_prob >= minprob:
                    methcall = -1

                # stat kept at 2-decimal precision
                ins_data.append((row['chrom'], pos, strand, row['readname'], round(mod_prob - can_prob, 2), methcall, mod_base))

                if len(ins_data) >= batch_size:
                    c.executemany(insert_sql, ins_data)
                    ins_data = []

            if ins_data:
                c.executemany(insert_sql, ins_data)

        logger.info('commiting records from %s to %s' % (fn, db_fn))
        conn.commit()
        loaded = True

    # record the mod name and its cutoffs once per database, not once per input file
    if loaded and mod_base not in known_mods:
        # NOTE(review): stat = mod_prob - can_prob; assuming the two probabilities
        # sum to ~1 (see the assert above), mod_prob >= minprob corresponds to
        # stat >= 2*minprob - 1 -- confirm against guppy_f5_fetch's output
        cutoff = round(2*minprob - 1, 4)
        c.execute('INSERT INTO modnames VALUES (?)', (mod_base,))
        c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (cutoff, -1*cutoff, mod_base))
        conn.commit()

    conn.close()
    logger.info('finished.')


def db_megalodon(args):
    '''
    Build a methylation db from megalodon per-read modified base calls.

    args.methdata: comma-separated per-read call tables (.gz allowed) with
                   columns read_id, chrm, pos, strand, mod_log_prob,
                   can_log_prob and mod_base
    args.db:       output db name; required when the input uses megalodon's
                   default per_read_modified_base_calls file name
    args.minprob:  a site is called 1 when exp(mod_log_prob) >= minprob,
                   -1 when exp(can_log_prob) >= minprob, otherwise 0
    args.append:   add to an existing db instead of creating one

    The stored stat is mod_log_prob - can_log_prob; minus-strand positions are
    shifted by -1 to match the nanopolish scheme. Each mod not yet listed in the
    db gets one modnames row and one cutoffs row of +/- log(minprob/(1-minprob)).
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    # compare the file name only, so a directory prefix does not hide the default name
    if os.path.basename(basename) == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.megalodon.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    conn = sqlite3.connect(db_fn)

    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        known_mods = set()

    else:
        logger.info('appending records to %s' % db_fn)
        known_mods = set(row[0] for row in c.execute('SELECT mod FROM modnames'))

    minprob = float(args.minprob)

    # minprob expressed as a log probability ratio: log(p) - log(1 - p)
    lpr_cutoff = float(np.log(minprob) - np.log(1-minprob))

    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    for tsv in args.methdata.split(','):
        opener = gzip.open if tsv.endswith('.gz') else open

        with opener(tsv, 'rt') as methdata:
            logger.info('parsing ' + tsv)

            csv_reader = csv.DictReader(methdata, delimiter='\t')

            modnames = set()
            ins_data = []

            for i, row in enumerate(csv_reader, 1):
                pos      = int(row['pos'])
                strand   = row['strand']
                mod_prob = float(row['mod_log_prob'])
                can_prob = float(row['can_log_prob'])
                mod_base = row['mod_base']

                modnames.add(mod_base)

                # adjust position of - strand calls to match nanopolish scheme
                if strand == '-':
                    pos -= 1

                methcall = 0

                if np.exp(mod_prob) >= minprob:
                    methcall = 1

                if np.exp(can_prob) >= minprob:
                    methcall = -1

                ins_data.append((row['chrm'], pos, strand, row['read_id'], mod_prob - can_prob, methcall, mod_base))

                if i % progress_interval == 0:
                    c.executemany(insert_sql, ins_data)
                    ins_data = []
                    logger.info('processed %d records from %s' % (i, tsv))

            if ins_data:
                c.executemany(insert_sql, ins_data)

        # one modnames row and one cutoffs row for every mod in this file that
        # the db does not already list (all mods, not just the last one seen)
        for mod in sorted(modnames - known_mods):
            c.execute('INSERT INTO modnames VALUES (?)', (mod,))
            c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(lpr_cutoff, 4), round(-1*lpr_cutoff, 4), mod))

        known_mods |= modnames

        logger.info('commiting records from %s to %s' % (tsv, db_fn))
        conn.commit()

    conn.close()


def db_nanopolish(args):
    '''
    Build a methylation db from nanopolish call-methylation output.

    args.methdata:   comma-separated nanopolish tables (.gz allowed) with columns
                     chromosome, strand, start, read_name, log_lik_ratio,
                     num_motifs and sequence
    args.db:         optional output db name (default: <methdata basename>.nanopolish.db)
    args.thresh:     llr > thresh is called 1, llr < -thresh is called -1, otherwise 0
    args.scalegroup: divide llr by num_motifs before calling
    args.modname:    modification name recorded for every site
    args.append:     add to an existing db instead of creating one

    One methdata row is written per CG in the record's sequence, at start plus
    that CG's offset from the first CG (as in nanopolish's
    calculate_methylation_frequency.py); stat is llr rounded to 2 decimals.
    The mod gets one modnames row and a cutoffs row of +/- thresh if the db
    does not already list it.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    if basename.endswith('.tsv'):
        basename = '.'.join(basename.split('.')[:-1])

    db_fn = basename + '.nanopolish.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    conn = sqlite3.connect(db_fn)

    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
        known_mods = set()

    else:
        logger.info('appending records to %s' % db_fn)
        known_mods = set(row[0] for row in c.execute('SELECT mod FROM modnames'))

    thresh = float(args.thresh)
    mod_base = args.modname

    progress_interval = 500000
    batch_size = 100000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    for tsv in args.methdata.split(','):
        opener = gzip.open if tsv.endswith('.gz') else open

        with opener(tsv, 'rt') as methdata:
            logger.info('parsing ' + tsv)

            csv_reader = csv.DictReader(methdata, delimiter='\t')

            ins_data = []

            for i, row in enumerate(csv_reader, 1):
                r_start  = int(row['start'])
                llr      = float(row['log_lik_ratio'])
                seq      = row['sequence']

                if args.scalegroup:
                    llr = llr/float(row['num_motifs'])

                methcall = 0

                if llr > thresh:
                    methcall = 1

                elif llr < -thresh:
                    methcall = -1

                # stat kept at 2-decimal precision
                stat = round(llr, 2)

                # get per-CG position (nanopolish/calculate_methylation_frequency.py)
                cg_pos = seq.find("CG")
                first_cg_pos = cg_pos

                while cg_pos != -1:
                    cg_start = r_start + cg_pos - first_cg_pos
                    cg_pos = seq.find("CG", cg_pos + 1)

                    ins_data.append((row['chromosome'], cg_start, row['strand'], row['read_name'], stat, methcall, mod_base))

                if len(ins_data) >= batch_size:
                    c.executemany(insert_sql, ins_data)
                    ins_data = []

                if i % progress_interval == 0:
                    logger.info('processed %d records from %s' % (i, tsv))

            if ins_data:
                c.executemany(insert_sql, ins_data)

        if mod_base not in known_mods:
            c.execute('INSERT INTO modnames VALUES (?)', (mod_base,))

            # the cutoffs are the calling threshold used above
            c.execute('INSERT INTO cutoffs VALUES (?, ?, ?)', (round(thresh, 4), round(-1*thresh, 4), mod_base))
            known_mods.add(mod_base)

        logger.info('commiting records from %s to %s' % (tsv, db_fn))

        conn.commit()

    conn.close()


def segmeth(args):
    '''
    segment methylation stats over genomic intervals

    args.data:       file of whitespace-delimited "<bam> <methylation db>" pairs, one per line
    args.intervals:  whitespace-delimited intervals: chrom start end [name [strand]]
    args.procs:      number of worker processes
    args.excl_ambig: adds 'excl_ambig' to the output file name (args is also
                     passed to get_segmeth_calls, which presumably uses it -- confirm)

    get_segmeth_calls runs in a worker pool for every bam x interval pair. The
    output <intervals>.<data>[.excl_ambig].segmeth.tsv has one row per segment
    with meth_calls, unmeth_calls, no_calls, methfrac and readcount columns for
    every sample x mod. Segments missing any of those columns are skipped with
    a warning.
    '''
    stats = [
        'meth_calls',
        'unmeth_calls',
        'no_calls',
        'methfrac',
        'readcount'
    ]

    mod_names = set()

    data = {}

    with open(args.data) as data_file:
        for line in data_file:
            if not line.strip():
                continue  # tolerate blank lines

            bam, meth = line.strip().split()
            data[bam] = meth

            mod_names.update(get_modnames(meth))

    mod_names = list(mod_names)

    logger.info('found mod names: %s' % ','.join(mod_names))

    base_names = ['.'.join(bam.split('.')[:-1]) for bam in data]

    data_basename = '.'.join(args.data.split('.')[:-1])
    ivl_basename = '.'.join(args.intervals.split('.')[:-1])

    if args.excl_ambig:
        outfn = '.'.join((ivl_basename, data_basename, 'excl_ambig', 'segmeth.tsv'))
    else:
        outfn = '.'.join((ivl_basename, data_basename, 'segmeth.tsv'))

    logger.info('segmeth output filename: %s' % outfn)

    # parse the interval file once; it is reused for every bam
    intervals = []

    with open(args.intervals) as ivl_file:
        for line in ivl_file:
            c = line.strip().split()

            if not c:
                continue  # tolerate blank lines

            seg_name = c[3] if len(c) > 3 else 'NA'
            seg_strand = c[4] if len(c) > 4 else 'NA'

            intervals.append((c[0], int(c[1]), int(c[2]), seg_name, seg_strand))

    meth_segs = dd(dict)

    # the pool is terminated when the block exits, after every result is collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        for bam_fn, meth_fn in data.items():
            base_name = '.'.join(bam_fn.split('.')[:-1])

            for chrom, seg_start, seg_end, seg_name, seg_strand in intervals:
                res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mod_names, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand])
                results.append((res, base_name))

        for res, base_name in results:
            meth_result, seg = res.get()

            if meth_result is None:
                continue

            seg_id = '%s:%d-%d' % seg[:3]

            seg_chrom, seg_start, seg_end, seg_name, seg_strand, read_count = map(str, seg)

            seg_data = meth_segs[seg_id]

            seg_data['seg_id']     = seg_id
            seg_data['seg_chrom']  = seg_chrom
            seg_data['seg_start']  = seg_start
            seg_data['seg_end']    = seg_end
            seg_data['seg_name']   = 'NoName' if seg_name == 'NA' else seg_name
            seg_data['seg_strand'] = '.' if seg_strand == 'NA' else seg_strand

            sample = os.path.basename(base_name)

            for modname, meth_data in meth_result.items():
                meth_calls   = meth_data.get(1, 0)
                unmeth_calls = meth_data.get(-1, 0)
                no_calls     = meth_data.get(0, 0)

                prefix = sample + '_' + modname + '_'

                seg_data[prefix + 'meth_calls'] = meth_calls
                seg_data[prefix + 'unmeth_calls'] = unmeth_calls
                seg_data[prefix + 'no_calls'] = no_calls

                if meth_calls + unmeth_calls == 0:
                    seg_data[prefix + 'methfrac'] = 'NaN'
                else:
                    seg_data[prefix + 'methfrac'] = meth_calls/float(meth_calls+unmeth_calls)

                seg_data[prefix + 'readcount'] = read_count

    header = ['seg_id', 'seg_chrom', 'seg_start', 'seg_end', 'seg_name', 'seg_strand']

    for bn in base_names:
        for m in mod_names:
            for s in stats:
                header.append('%s_%s_%s' % (os.path.basename(bn), m, s))

    with open(outfn, 'w') as out:
        out.write('\t'.join(header) + '\n')

        for seg_id, seg_data in meth_segs.items():
            missing = [h for h in header if h not in seg_data]

            # skip the whole row; writing the columns found so far would
            # produce a truncated, misaligned line
            if missing:
                logger.warning('no calls for sample %s in segment %s, skipped.' % (missing[0], seg_id))
                continue

            out.write('\t'.join(str(seg_data[h]) for h in header) + '\n')


def segplot(args):
    '''
    strip plots / violin plots of segmented methylation data over categories by sample

    args.segmeth:    tab-delimited segmeth output (first column used as the index)
    args.samples:    optional comma-separated subset of samples to plot
    args.mods:       optional comma-separated subset of mods to plot
    args.mincalls:   segments where any plotted sample x mod has fewer than this
                     many meth + unmeth calls are dropped
    args.categories: optional comma-separated x-axis order of seg_name groups
    args.violin:     violin plot instead of strip plot
    args.pointsize, args.palette, args.tiltlabel, args.ymin, args.ymax,
    args.width, args.height, args.svg: plot appearance / output format

    Writes <segmeth basename>.segplot_data.csv and the plot image.
    '''
    data = pd.read_csv(args.segmeth, sep='\t', header=0, index_col=0)

    samples = []
    mods = []

    suffix = '_meth_calls'

    for col in data.columns:
        if col.endswith(suffix):
            # columns are <sample>_<mod>_meth_calls; the mod is the last '_' field
            sample, _, mod = col[:-len(suffix)].rpartition('_')

            # there is one such column per sample x mod, so de-duplicate both lists
            if sample not in samples:
                samples.append(sample)

            if mod not in mods:
                mods.append(mod)

    logger.info('available samples: %s' % ','.join(samples))
    logger.info('available mods: %s' % ','.join(mods))

    if args.samples:
        user_samples = []
        for s in args.samples.split(','):
            if s not in samples:
                sys.exit('%s not in sample list!' % s)
            user_samples.append(s)

        samples = user_samples

    if args.mods:
        user_mods = []
        for m in args.mods.split(','):
            if m not in mods:
                sys.exit('%s not in mod list!' % m)
            user_mods.append(m)

        mods = user_mods

    samples_mods = [s + '_' + m for s, m in itertools.product(samples, mods)]

    logger.info('sample + mod permutations: %s' % ','.join(samples_mods))

    for sm in samples_mods:
        data[sm + '_methfrac'] = data[sm + '_meth_calls']/(data[sm + '_meth_calls']+data[sm + '_unmeth_calls'])

    mincalls = int(args.mincalls)

    keep = np.ones(len(data), dtype=bool)

    for sm in samples_mods:
        calls = (data[sm + '_meth_calls'] + data[sm + '_unmeth_calls']).to_numpy()
        # written as ~(calls < mincalls): NaN counts compare False and keep the segment
        keep &= ~(calls < mincalls)

    data = data.loc[keep]

    logger.info('useable sites: %d' % len(data))

    plot_data = dd(dict)

    order = []

    for seg in data.index:
        for sm in samples_mods:
            uid = seg + ':' + sm
            plot_data[uid]['sample'] = sm
            plot_data[uid]['mCpG']   = data[sm + '_methfrac'].loc[seg]
            plot_data[uid]['group']  = data['seg_name'].loc[seg]

            if plot_data[uid]['group'] not in order:
                order.append(plot_data[uid]['group'])

    # round trip through to_dict so each column's dtype is re-inferred (mCpG -> float)
    plot_data = pd.DataFrame.from_dict(plot_data).T
    plot_data = pd.DataFrame(plot_data.to_dict())

    basename = '.'.join(args.segmeth.split('.')[:-1])

    plot_data.to_csv(basename+'.segplot_data.csv')
    logger.info('plot data written to %s.segplot_data.csv' % basename)

    pt_sz = int(args.pointsize)

    if args.categories is not None:
        order = args.categories.split(',')

    if args.violin:
        # violinplot has no jitter option (only stripplot does)
        sns_plot = sns.violinplot(x='group', y='mCpG', data=plot_data, hue='sample', dodge=True, order=order, hue_order=samples_mods, palette=args.palette)
        basename += '.violin'

    else:
        sns_plot = sns.stripplot(x='group', y='mCpG', data=plot_data, hue='sample', dodge=True, jitter=True, size=pt_sz, order=order, hue_order=samples_mods, palette=args.palette)

    sns_plot.set_xlabel("")

    if args.tiltlabel:
        sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=45)

    sns_plot.set_ylim(float(args.ymin),float(args.ymax))

    if args.width is None:
        args.width = 1 + len(order) * len(samples_mods)
        logger.info('auto set --width: %d' % args.width)

    fig = sns_plot.figure
    fig.set_size_inches(float(args.width), float(args.height))

    basename += '.mc%d' % mincalls

    if args.svg:
        fig.savefig(basename+'.segplot.svg', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.svg' % basename)
    else:
        fig.savefig(basename+'.segplot.png', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.png' % basename)


def locus(args):
    '''
    plot methylation profile of a region in CpG space
    '''
    # set up

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos= args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))
    
    data = {}
    user_colours = {}

    with open(args.data) as _:
        for line in _:
            c = line.strip().split()
            if len(c) < 2:
                logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                sys.exit()

            bam, meth_db = c[:2]
            data[bam] = meth_db

            if len(c) == 3:
                user_colours[bam] = c[2]

    # table for plotting

    meth_table = dd(dict)

    sample_order = []

    reads = {}
    orig_bam = {}

    use_mods = None
    if args.mods:
        use_mods = args.mods.split(',')

    phases = {}

    if args.phased:
        for bam in data:
            phased_reads = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=args.include_unphased)
            phases[bam] = list(set([p for p in phased_reads.values() if p]))
            logger.info('phases for bam %s: %s' % (bam, ','.join(phases[bam])))

    for bam, meth_db in data.items():
        mods = sorted(get_modnames(meth_db))
        logger.info('found mods: %s in db %s' % (','.join(mods), meth_db))

        for mod in mods:
            if use_mods:
                if mod not in use_mods:
                    logger.info('skipping %s, not specified in -m/--mods %s' % (mod, args.mods))
                    continue
            else:
                use_mods = mods

            if args.phased:
                for phase in phases[bam]:
                    bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + phase + '.' + mod
                    orig_bam[bamname] = bam
                    reads[bamname] = get_meth_locus(args, bam, meth_db, mod, phase=phase)

                    for name, read in reads[bamname].items():
                        for loc in read.llrs.keys():
                            uuid = str(uuid4())
                            meth_table[uuid]['loc'] = loc
                            meth_table[uuid]['llr'] = read.llrs[loc]
                            meth_table[uuid]['read'] = name
                            meth_table[uuid]['sample'] = bamname
                            meth_table[uuid]['call'] = read.meth_calls[loc]

                            if bamname not in sample_order:
                                sample_order.append(bamname)

            else:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_db, mod)

                for name, read in reads[bamname].items():
                    for loc in read.llrs.keys():
                        uuid = str(uuid4())
                        meth_table[uuid]['loc'] = loc
                        meth_table[uuid]['llr'] = read.llrs[loc]
                        meth_table[uuid]['read'] = name
                        meth_table[uuid]['sample'] = bamname
                        meth_table[uuid]['call'] = read.meth_calls[loc]

                        if bamname not in sample_order:
                            sample_order.append(bamname)

    meth_table = pd.DataFrame.from_dict(meth_table).T

    if 'loc' not in meth_table:
        sys.exit('%s: insufficient coverage for plot' % args.interval)

    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['llr'] = pd.to_numeric(meth_table['llr'])

    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    if len(meth_table['orig_loc']) == 0:
        sys.exit('%s: insufficient coverage for plot' % args.interval)        

    # optional mincall filter
    args.mincalls = int(args.mincalls)

    if args.mincalls > 0:
        drop_ix = []

        logger.info('assessing coverage per --mincalls, per-read sites to consider: %d' % len(meth_table['loc']))
        tick = 5000
        for i, new_loc in enumerate(meth_table['loc'], 1):
            for sample in sample_order:
                call_count = len(meth_table.loc[(meth_table['sample'] == sample) & (meth_table['loc'] == new_loc)])

                if call_count < args.mincalls:
                    drop_ix += list(meth_table.loc[(meth_table['loc'] == new_loc)].index)
            
            if i % tick == 0:
                logger.info('%s %s: processed %d sites' % (args.data, args.interval, i))
        
        drop_ix = list(set(drop_ix))
        logger.info('dropping %d positions' % len(drop_ix))
        meth_table.drop(drop_ix, inplace=True)

        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        if len(meth_table['orig_loc']) == 0:
            sys.exit('%s: insufficient coverage for plot' % args.interval)

    # create mod space

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    # calibrate plotting parameters
    logger.info('region size: %d' % (elt_end-elt_start))

    modspace_n = len(list(set(coord_to_cpg.values())))
    logger.info('mod space positions: %d' % modspace_n)

    if args.modspace is not None:
        logger.info('user set --modspace %d, mod_n:ticks: %.3f' % (int(args.modspace), modspace_n/float(args.modspace)))
        args.modspace = int(args.modspace)
    else:
        if modspace_n <= 1200:
            args.modspace = 1
        else:
            args.modspace = round(modspace_n/1200)

        logger.info('auto set --modspace %d' % args.modspace)

    if args.smoothwindowsize is not None:
        logger.info('user set --smoothwindowsize %d, mod_n:smoothwindowsize: %.3f' % (int(args.smoothwindowsize), modspace_n/float(args.smoothwindowsize)))
        args.smoothwindowsize = int(args.smoothwindowsize)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            logger.info('-s/--smoothwindowsize must be an even integer, adjusted value: %d' % args.smoothwindowsize)

    else:
        args.smoothwindowsize = round(0.0167*modspace_n + 18)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1

        logger.info('auto set --smoothwindowsize %d' % args.smoothwindowsize)

    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s - elt_start)
            h_end.append(h_e - elt_start)

            h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()
                assert c[0] == chrom, 'all entries in --highlight_bed must be on the same chromosome as -i/--interval'

                h_s, h_e = map(int, c[1:3])

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s - elt_start)
                h_end.append(h_e - elt_start)

                h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    fig = plt.figure()
    gs = gridspec.GridSpec(5,1,height_ratios=[1,5,1,3,3], hspace=0)

    img_w = 16
    img_h = 8

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(5,1,height_ratios=p_ratios, hspace=0)

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, elt_start, elt_end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            logger.warning('no transcript for gene: %s' % genes[ensg].name)
            continue

        logger.info('gene in region: %s' % genes[ensg].name)

        y = 1

        while genemap[y].find(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start):
            y += 1

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start))

        if args.labelgenes:
            lg_x  = max(genes[ensg].tx_start-elt_start, 0)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(elt_end-elt_start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    logger.info('building read alignment plot...')
    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    readstack = dd(list)

    max_y  = 1
    pack_y = 1

    masked_count = 0

    for bamname in reads:
        fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
        pos_cache = {}

        for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
            if read.mapq < 10:
                continue

            if read.is_supplementary or read.is_secondary or read.is_duplicate:
                continue

            masked = False
            if len(readmask) > 0:
                for mask_start, mask_end in readmask:
                    if read.reference_start >= mask_start and read.reference_end <= mask_end:
                        logger.debug('masked read: %s' % read.query_name)
                        masked_count += 1
                        masked = True
            
            if masked:
                continue

            if read.query_name not in pos_cache:
                pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
            else:
                pos_cache[read.query_name][0].append(read.reference_start)
                pos_cache[read.query_name][1].append(read.reference_end)

        for readname, read in reads[bamname].items():
            if readname not in pos_cache:
                logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                continue

            if read.call_count == 0:
                continue

            read.ypos = max_y
            max_y += 1

            read.starts, read.ends = pos_cache[readname]
            readstack[bamname].append(read)

        fetch_reads_bam.close()

    if args.readmask:
        logger.info('masked %d reads due to --readmask' % masked_count)

    # read packing

    for bamname in readstack:
        reads = readstack[bamname]

        y = dd(list)

        for read in readstack[bamname]:
            y[read.ypos].append(read)

        for p in y:
            for q in y:
                if p == q:
                    continue

                for read_q in y[q]:
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

        for p in y:
            if len(y[p]) > 0:
                for read in sorted(y[p], key=lambda r: min(r.starts)):
                    read.ypos = pack_y
                pack_y += 1

    ax1.set_ylim(0,pack_y+1)

    for bamname in readstack:
        for read in readstack[bamname]:
            for call_pos, call in read.meth_calls.items():
                if call == -1:
                    ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=3)

                if call == 1:
                    ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=3)

            for i in range(len(read.starts)):
                readline_start = max(read.starts[i], elt_start) - elt_start
                readline_end   = min(read.ends[i], elt_end) - elt_start

                ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos-0.1, read.ypos-0.1], zorder=2, color=sample_color[bamname], alpha=0.4))

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax1.get_ybound())
                yheight = max(ax1.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    orig_positions = list(set(meth_table['orig_loc']))

    for i, x in enumerate(orig_positions):
        if i in (0, len(orig_positions)-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])

    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = 10
    tick_interval = (elt_end-elt_start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    xt_labels = [str(int(t+elt_start)) for t in tick_list]
    ax0.set_xticks(tick_list) 
    ax0.set_xticklabels(xt_labels)

    # llr plot

    logger.info('building llr plot...')
    ax4 = plt.subplot(gs[3])
    ax4.set_xticks([])
    ax4.yaxis.tick_right()

    ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)

    for mod in use_mods: 
        upper, lower = get_cutoffs(list(data.values())[0], mod)
        ax4.axhline(y=upper, c='k', linestyle='--',lw=1)
        ax4.axhline(y=lower, c='k', linestyle='--',lw=1)

    ax4 = sns.lineplot(x='loc', y='llr', hue='sample', data=meth_table, palette=sample_color, zorder=2)
    ax4.set_xlim(ax2.get_xlim())

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax5 = plt.subplot(gs[4])

    order_stack = 2

    for sample in sample_order:
        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=4, color=sample_color[sample])

        order_stack += 1

        for seg in masked_segs:
            if len(seg) > 2:
                mf_seg = np.asarray(smoothed_methfrac)[seg]
                pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]
            
                ax5.plot(pos_seg, mf_seg, marker='', lw=4, color='#ffffff', alpha=0.8, zorder=order_stack)

                order_stack += 1

    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((-0.05,1.05))

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax5.get_ybound())
                yheight = max(ax5.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax5.add_patch(h)

    fn_prefix = '.'.join(args.data.split('.')[:-1]) + '.' + '_'.join(args.interval.split(':')[:2]) + '.' + ''.join(use_mods)

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    if args.phased:
        fn_prefix += '.phased'

    param_str = '.ms%d.smw%d' % (args.modspace, int(args.smoothwindowsize))

    fn_prefix += param_str

    if args.max_read_density:
        fn_prefix += '.mrd%.2f' % float(args.max_read_density)

    if args.mincalls > 0:
        fn_prefix += '.mc%d' % args.mincalls
    
    if args.svg:
        plt.savefig('%s.locus.meth.svg' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.svg' % fn_prefix)

    else:
        plt.savefig('%s.locus.meth.png' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.png' % fn_prefix)


def region(args):
    '''
    plotting function for windowed view of larger regions

    Splits args.interval (chrom:start-end) into consecutive windows that each
    hold roughly the same number of args.norm_motif sites (either strand),
    computes per-window methylated fractions for every bam/mod listed in
    args.data, and draws up to four stacked panels: genes (when --gtf is
    given), per-read calls (unless --skip_align_plot), the genome-to-window
    coordinate map, and smoothed methylation fractions. The figure is saved as
    <prefix>.region.meth.svg (with --svg) or <prefix>.region.meth.png.

    Exits via sys.exit() when a -d/--data line has fewer than two fields, when
    more than --maxuncovered percent of windows lack calls for a mod, or when
    no methylation calls are found at all.
    '''

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    start, end = pos.split('-')

    start = int(start.replace(',',''))
    end = int(end.replace(',',''))

    assert start < end

    if end-start < 1000000:
        logger.warning('locus smaller than 1Mbp, "methylartist locus" may yield better results')

    ref = pysam.FastaFile(args.ref)
    motif = args.norm_motif.upper()
    region_seq = ref.fetch(chrom, start, end).upper()
    ref.close()

    # Count sites on both strands, but do not count palindromic motifs (e.g. CG,
    # whose reverse complement is itself) twice: the window scan below counts
    # each such site once, so the totals must agree or the requested window
    # count would only be half honoured.
    motif_count = region_seq.count(motif)
    if rc(motif) != motif:
        motif_count += region_seq.count(rc(motif))

    if args.windows is None:
        args.windows = round(motif_count / 30)
        logger.info('set window count to %d' % args.windows)

    args.windows = int(args.windows)

    if args.windows < 500:
        args.windows = 500
        logger.info('resetting windows to a minimum of 500')

    motifs_per_window = int(motif_count / int(args.windows))

    if motifs_per_window == 0:
        motifs_per_window = 1

    logger.info('motif count: %d, per window: %d' % (motif_count, motifs_per_window))

    # Walk the sequence and close a window every motifs_per_window sites;
    # window i spans [w_starts[i], w_ends[i]) in genomic coordinates.
    w_starts = [start]
    w_ends = []

    i = 0
    motif_count = 0

    while i < len(region_seq) - len(motif):
        site_fwd = region_seq[i:i+len(motif)]
        site_rev = rc(site_fwd)

        if motif in (site_fwd, site_rev):
            motif_count += 1

        i += 1

        if motif_count == motifs_per_window:
            w_starts.append(start+i)
            w_ends.append(start+i)
            motif_count = 0

    w_ends.append(end)

    assert len(w_starts) == len(w_ends)
    logger.info('using %d windows normalised for %s content' % (len(w_starts), args.norm_motif))

    if args.smoothwindowsize is None:
        w = len(w_starts)
        args.smoothwindowsize = round(0.02*w + 4)
        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1

        logger.info('set --smoothwindowsize to %d' % args.smoothwindowsize)

    args.smoothwindowsize = int(args.smoothwindowsize)

    if args.modspace is None:
        args.modspace = round(len(w_starts)/300)
        if args.modspace == 0:
            args.modspace = 1

        logger.info('set --modspace to %d' % args.modspace)

    args.modspace = int(args.modspace)

    # -d/--data: whitespace-separated "<bam> <meth .db> [colour]" per line
    data = {}
    mods = []
    user_colours = {}

    with open(args.data) as data_fh:
        for line in data_fh:
            c = line.strip().split()

            if not c:
                continue  # tolerate blank lines (e.g. a trailing newline)

            if len(c) < 2:
                logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                sys.exit()

            bam, meth_db = c[:2]
            data[bam] = meth_db
            mods += sorted(get_modnames(meth_db))

            if len(c) == 3:
                user_colours[bam] = c[2]

    mods = list(set(mods))

    logger.info('found mods: %s' % ','.join(mods))

    if args.mods:
        for mod in args.mods.split(','):
            assert mod in mods, 'mod %s not found' % mod
        mods = args.mods.split(',')

    logger.info('using mods %s' % ','.join(mods))

    # per-read data keyed by "<bam basename without extension>.<mod>"
    reads = {}
    orig_bam = {}

    for bam, meth_db in data.items():
        for mod in mods:
            bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
            orig_bam[bamname] = bam
            reads[bamname] = get_meth_locus(args, bam, meth_db, mod)

    pool = mp.Pool(processes=int(args.procs))

    results = []

    for seg_start, seg_end in zip(w_starts, w_ends):
        for bam_fn, meth_fn in data.items():
            seg_strand = '.'
            seg_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
            res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mods, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand])
            results.append(res)

    # all work is submitted; workers exit once the task queue drains
    pool.close()

    meth_segs = dd(dict)

    sample_names = {}

    logger.info('parsing segments...')

    shallow_windows = dd(list)
    min_window_calls = int(args.min_window_calls)

    for res in results:
        meth_result, seg = res.get()
        if meth_result is None:
            continue

        seg_id = '%s:%d-%d' % seg[:3]

        seg_chrom, seg_start, seg_end, seg_name, seg_strand = map(str, seg[:-1])

        meth_segs[seg_id]['seg_id']     = seg_id
        meth_segs[seg_id]['seg_chrom']  = seg_chrom
        meth_segs[seg_id]['seg_start']  = seg_start
        meth_segs[seg_id]['seg_end']    = seg_end
        meth_segs[seg_id]['seg_name']   = seg_name
        meth_segs[seg_id]['seg_strand'] = seg_strand

        for modname, meth_data in meth_result.items():
            # meth_data maps call value (-1 unmethylated, 1 methylated) to a count
            meth_calls = 0
            unmeth_calls = 0

            if -1 in meth_data:
                unmeth_calls = meth_data[-1]

            if 1 in meth_data:
                meth_calls = meth_data[1]

            if meth_calls + unmeth_calls < min_window_calls:
                if seg_id not in shallow_windows[modname]:
                    shallow_windows[modname].append(seg_id)

            sample_name = seg_name + '.' + modname
            sample_names[sample_name] = True

            meth_segs[seg_id][sample_name + '_meth_calls'] = meth_calls
            meth_segs[seg_id][sample_name + '_unmeth_calls'] = unmeth_calls

            if meth_calls + unmeth_calls == 0:
                meth_segs[seg_id][sample_name + '_frac'] = 0 # changed from NaN
            else:
                meth_segs[seg_id][sample_name + '_frac'] = meth_calls/float(meth_calls+unmeth_calls)

    pool.join()

    for mod in mods:
        shallow_frac = len(shallow_windows[mod])/len(w_starts)*100.0
        if shallow_frac > 0.0:
            logger.info('%.2f percent of windows for mod %s had less than %d calls' % (shallow_frac, mod, min_window_calls))
            if shallow_frac > float(args.maxuncovered):
                sys.exit('greater than %.2f percent of windows are uncovered, aborting.' % float(args.maxuncovered))

    # drop any window that is under-covered in at least one mod
    deleted_segs = 0
    for mod in mods:
        for seg_id in shallow_windows[mod]:
            if seg_id in meth_segs:
                del meth_segs[seg_id]
                deleted_segs += 1

    if deleted_segs > 0:
        logger.info('removed %d segs with less than %d calls in at least one mod' % (deleted_segs, min_window_calls))

    meth_segs = pd.DataFrame.from_dict(meth_segs).T

    if 'seg_start' not in meth_segs:
        logger.warning('no methylation calls.')
        sys.exit()

    meth_segs['seg_start'] = pd.to_numeric(meth_segs['seg_start'])

    # 'pos' is the window index in genomic order: the x axis of the mod-space panels
    meth_segs.sort_values('seg_start', inplace=True)
    meth_segs['pos'] = np.arange(len(meth_segs.index))

    for sample in sample_names:
        meth_segs[sample] = smooth(np.asarray(meth_segs[sample + '_frac']), window_len=int(args.smoothwindowsize), window=args.smoothfunc)
        meth_segs[sample] = meth_segs[sample].rolling(window=10, min_periods=1, center=True).mean()

    coord_to_pos = {}
    for orig_loc, new_loc in zip(meth_segs['seg_start'], meth_segs['pos']):
        coord_to_pos[orig_loc] = new_loc

    # highlights: genomic extents plus the nearest window index for each end

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]

            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()
                assert c[0] == chrom, 'all entries in --highlight_bed must be on the same chromosome as -i/--interval'

                h_s, h_e = map(int, c[1:3])

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1

                h_start.append(h_s)
                h_end.append(h_e)

                h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    fig = plt.figure()
    gs = gridspec.GridSpec(4,1,height_ratios=[1,5,1,3], hspace=0)

    img_w = 16
    img_h = 8

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(4,1,height_ratios=p_ratios, hspace=0)

    sample_order = sorted(sample_names.keys())

    logger.info('sample order: %s' % ','.join(sample_order))

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    # a colour given in the -d/--data file overrides the palette
    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = []

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, start, end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    # one interval tree per gene row, used to stack overlapping transcripts
    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]

        genes = new_genes

    if args.gtf:
        logger.info('%d genes in region' % len(genes))

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            continue

        # first row whose existing transcripts do not overlap this one
        y = 1

        while genemap[y].find(genes[ensg].tx_start, genes[ensg].tx_end):
            y += 1

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start, genes[ensg].tx_end], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start, genes[ensg].tx_end))

        if args.labelgenes:
            # centre the label on the (clipped) transcript start, converting
            # its rendered pixel width into genomic units
            lg_x  = max(genes[ensg].tx_start, start)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(end-start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    if args.skip_align_plot:
        logger.info('skipped alignment plot due to --skip_align_plot')

    else:
        logger.info('building read alignment plot...')

        readstack = dd(list)

        max_y  = 1
        pack_y = 1

        masked_count = 0

        for bamname in reads:
            fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
            pos_cache = {}  # read name -> ([ref starts], [ref ends]) of its primary alignments

            for read in fetch_reads_bam.fetch(chrom, start, end):
                if read.mapq < 10:
                    continue

                if read.is_supplementary or read.is_secondary or read.is_duplicate:
                    continue

                masked = False
                for mask_start, mask_end in readmask:
                    if read.reference_start >= mask_start and read.reference_end <= mask_end:
                        logger.debug('masked read: %s' % read.query_name)
                        masked_count += 1
                        masked = True
                        break  # count each masked read once, even if masks overlap

                if masked:
                    continue

                if read.query_name not in pos_cache:
                    pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
                else:
                    pos_cache[read.query_name][0].append(read.reference_start)
                    pos_cache[read.query_name][1].append(read.reference_end)

            for readname, read in reads[bamname].items():
                if readname not in pos_cache:
                    logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                    continue

                if read.call_count == 0:
                    continue

                # provisional one-read-per-row layout; compacted below
                read.ypos = max_y
                max_y += 1

                read.starts, read.ends = pos_cache[readname]
                readstack[bamname].append(read)

            fetch_reads_bam.close()

        if args.readmask:
            logger.info('masked %d reads due to --readmask' % masked_count)

        # read packing: greedily move reads into any other row of the same
        # sample where they overlap nothing, then renumber non-empty rows

        for bamname in readstack:
            y = dd(list)

            for read in readstack[bamname]:
                y[read.ypos].append(read)

            for p in y:
                for q in y:
                    if p == q:
                        continue

                    # iterate over a snapshot: y[q] is modified inside the loop
                    for read_q in list(y[q]):
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

            for p in y:
                if len(y[p]) > 0:
                    for read in sorted(y[p], key=lambda r: min(r.starts)):
                        read.ypos = pack_y
                    pack_y += 1

        ax1.set_ylim(0,pack_y+1)

        for bamname in readstack:
            for read in readstack[bamname]:
                for call_pos, call in read.meth_calls.items():
                    # call positions are offset by the region start to match
                    # the genomic x axis of this panel
                    call_pos += start

                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=3)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=3)

                for i in range(len(read.starts)):
                    readline_start = max(read.starts[i], start)
                    readline_end   = min(read.ends[i], end)

                    ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos-0.1, read.ypos-0.1], zorder=2, color=sample_color[bamname], alpha=0.4))

        highlight_patches = []

        if args.highlight_panels:
            a = float(args.highlight_panels)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_start):
                    h_e = h_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax1.get_ybound())
                    yheight = max(ax1.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])  # window-index axis (bottom)
    ax3 = ax2.twiny()         # genomic-coordinate axis (top)

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    # x1: window index, x2: genomic start; every step-th window plus both ends
    x1 = []
    x2 = []

    step = int(args.modspace)

    for i, x in enumerate(meth_segs['seg_start']):
        if i in (0, len(meth_segs['seg_start'])-1):
            x2.append(x)
            x1.append(coord_to_pos[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_pos[x])

    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = 10
    # range() rejects a zero step, which a region under n_ticks bp would produce
    tick_interval = max(1, int((end-start)/n_ticks))
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), tick_interval))

    xt_labels = [str(int(t)) for t in tick_list]
    ax0.set_xticks(tick_list)
    ax0.set_xticklabels(xt_labels)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax4 = plt.subplot(gs[3])

    for sample in sample_order:
        ax4.plot(meth_segs['pos'], meth_segs[sample], marker='', lw=4, color=sample_color[sample], zorder=2, label=sample)

    ax4.legend()
    ax4.set_xlim(ax2.get_xlim())
    ax4.set_ylim((-0.05,1.05))

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    # output
    fn_prefix = '.'.join(args.data.split('.')[:-1]) + '.' + '_'.join(args.interval.split(':')[:2]) + '.' + ''.join(mods)

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    fn_prefix += '.s%d.w%d.m%d' % (args.smoothwindowsize, args.windows, args.modspace)

    if args.svg:
        plt.savefig('%s.region.meth.svg' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.region.meth.svg' % fn_prefix)

    else:
        plt.savefig('%s.region.meth.png' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.region.meth.png' % fn_prefix)


def composite(args):
    '''
    plot composite methylation profiles relative to a consensus element

    Segments from --segdata whose seg_name is listed in --fams and that have at
    least --mincalls (meth + unmeth) calls for --sample are profiled in parallel
    with get_meth_profile_composite(). If more than --maxsegs segments qualify,
    a random subset of --maxsegs is used. Up to --outelts of the non-empty
    profiles are drawn as per-element lines in consensus (--teref) coordinates.
    They sit under a track that marks --motif occurrences and optional --blocks,
    and optionally below an across-element mean panel (--plotmean).

    The mod used is --mod when given. Otherwise it is the only mod in --meth,
    or else the last '_'-separated token of the sample name.

    Writes <sample>.<fam1_fam2...>[.plotmean].composite.{png,svg} to the
    working directory. Exits via sys.exit() on unusable inputs.
    '''
    te_ref_seq = single_seq_fa(args.teref).upper()

    if not os.path.exists(args.ref + '.fai'):
        sys.exit('ref fasta must be indexed')

    fams = args.fams.split(',')

    logger.info('fams: %s' % ','.join(fams))

    segdata = pd.read_csv(args.segdata, sep='\t', header=0, index_col=0)

    sample = args.sample

    # per-sample call counts live in <sample>_meth_calls / <sample>_unmeth_calls columns
    avail_samples = ['_'.join(col.split('_')[:-2]) for col in segdata if col.endswith('_meth_calls')]

    if sample not in avail_samples:
        sys.exit('%s not in available samples: %s' % (sample, ','.join(avail_samples)))

    avail_mods = get_modnames(args.meth)

    logger.info('available mods: %s' % ','.join(avail_mods))

    use_mod = _composite_pick_mod(getattr(args, 'mod', None), sample, avail_mods)

    # plotting bounds in consensus coordinates; validated before any work is queued
    mod_start = int(args.start) if args.start else 0
    mod_end = int(args.end) if args.end else len(te_ref_seq)

    if mod_start >= mod_end:
        sys.exit('--start (%d) must be less than --end (%d)' % (mod_start, mod_end))

    # select segments of the requested families with enough calls
    meth_col = sample + '_meth_calls'
    unmeth_col = sample + '_unmeth_calls'
    min_calls = int(args.mincalls)

    useable = []

    for seg in segdata.index:
        if segdata['seg_name'].loc[seg] not in fams:
            continue

        if (segdata[meth_col].loc[seg] + segdata[unmeth_col].loc[seg]) < min_calls:
            continue

        useable.append(seg)

    if len(useable) > int(args.maxsegs):
        useable = random.sample(useable, int(args.maxsegs))

    segdata = segdata.loc[useable]

    logger.info('useable segments: %d' % len(useable))

    # profile every segment in parallel; the pool is torn down once all
    # results have been collected (or if collecting one raises)
    out_res = []

    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        for seg in segdata.index:
            seg_chrom  = str(segdata['seg_chrom'].loc[seg])
            seg_start  = int(segdata['seg_start'].loc[seg])
            seg_end    = int(segdata['seg_end'].loc[seg])
            seg_strand = str(segdata['seg_strand'].loc[seg])
            seg_name   = str(segdata['seg_name'].loc[seg])

            results.append(pool.apply_async(get_meth_profile_composite, [args, seg_chrom, seg_start, seg_end, seg_name, seg_strand, use_mod]))

        for res in results:
            coord_meth_pos, meth_profile = res.get()

            if len(coord_meth_pos) == 0:
                continue

            out_res.append((coord_meth_pos, meth_profile))

    n_out = int(args.outelts)

    if len(out_res) < n_out:
        logger.info('available profiles (%d) is less than --outelts (%d)' % (len(out_res), n_out))
        n_out = len(out_res)

    # set up plot
    fig = plt.figure()

    if args.plotmean:
        gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 8])
    else:
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 8])

    g = 0

    # mean
    if args.plotmean:
        ax0 = plt.subplot(gs[g])
        g += 1
        ax0.set_ylim((-0.05, 1.05))
        ax0.set_xlim((mod_start, mod_end))
        ax0.set_xticks([])

        _composite_plot_mean(ax0, out_res, args.colour)

    # consensus element track
    ax1 = plt.subplot(gs[g])
    g += 1

    # background box spans exactly the plotted window [mod_start, mod_end]
    box = matplotlib.patches.Rectangle([mod_start, 0], mod_end - mod_start, 1.0, edgecolor='#555555', facecolor='#cfcfcf', zorder=1)
    ax1.add_patch(box)

    if args.blocks:
        with open(args.blocks) as blocks:
            for line in blocks:
                fields = line.strip().split()

                if not fields:
                    continue  # tolerate blank lines

                b_start, b_end, b_name, b_col = fields
                b_start = int(b_start)
                b_end = int(b_end)

                box = matplotlib.patches.Rectangle([b_start, 0], b_end - b_start, 1.0, edgecolor='#555555', facecolor=b_col, zorder=2)
                ax1.add_patch(box)

    # motif occurrences within [mod_start, mod_end]; te_ref_seq is upper-case,
    # so the motif is too. 'last' is the final index where the motif still fits.
    motif = args.motif.upper()
    mlen = len(motif)
    last = min(mod_end, len(te_ref_seq) - mlen)

    mod_locs = [i for i in range(max(mod_start, 0), last + 1) if te_ref_seq[i:i + mlen] == motif]

    ax1.vlines(mod_locs, 0, 1, lw=1, colors=('#FF4500'), zorder=3, alpha=0.5)

    ax1.spines['bottom'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim((mod_start, mod_end))
    ax1.xaxis.set_ticks_position('top')

    # wiggles: one line per (randomly chosen) element
    ax2 = plt.subplot(gs[g])
    g += 1

    for coord_meth_pos, meth_profile in random.sample(out_res, n_out):
        ax2.plot(coord_meth_pos, meth_profile, lw=float(args.linewidth), alpha=float(args.alpha), color=args.colour)

    ax2.set_ylim((-0.05, 1.05))
    ax2.set_xlim((mod_start, mod_end))

    fig.set_size_inches(16, 6)

    fn_base = sample + '.' + '_'.join(fams)

    if args.plotmean:
        fn_base += '.plotmean'

    if args.svg:
        plt.savefig('%s.composite.svg' % fn_base, bbox_inches='tight')
        logger.info('plotted to %s.composite.svg' % fn_base)
    else:
        plt.savefig('%s.composite.png' % fn_base, bbox_inches='tight')
        logger.info('plotted to %s.composite.png' % fn_base)

    plt.close(fig)


def _composite_pick_mod(requested, sample, avail_mods):
    '''
    choose the mod for composite(): an explicit --mod (validated), else the
    only available mod, else the last '_'-separated token of the sample name.
    Exits via sys.exit() if no valid choice can be made.
    '''
    if requested is not None:
        if requested not in avail_mods:
            sys.exit('--mod %s not available, pick with --mod from %s and re-run' % (requested, ','.join(avail_mods)))
        return requested

    if len(avail_mods) == 1:
        logger.info('using the one available mod: %s' % avail_mods[0])
        return avail_mods[0]

    if '_' not in sample:
        sys.exit('cannot guess which mod to use, pick with --mod from %s and re-run' % ','.join(avail_mods))

    use_mod = sample.split('_')[-1]

    if use_mod not in avail_mods:
        sys.exit('inferred mod %s not available, pick with --mod from %s and re-run' % (use_mod, ','.join(avail_mods)))

    return use_mod


def _composite_plot_mean(ax, out_res, colour):
    '''
    draw the across-element mean of methylation profiles on ax (seaborn
    lineplot with an sd band), restricted to well-covered consensus coordinates
    '''
    meth_by_coord = dd(list)
    per_elt_calls = []

    for coord_meth_pos, meth_profile in out_res:
        per_elt_calls.append(len(coord_meth_pos))
        for c, m in zip(coord_meth_pos, meth_profile):
            meth_by_coord[c].append(m)

    if not meth_by_coord:
        logger.warning('no methylation profiles available, mean panel left empty')
        return

    median_call_count = int(np.median(per_elt_calls))

    # keep approximately the median element's number of best-covered sites:
    # the cutoff is the call depth of the median_call_count-th deepest site
    site_depths = sorted((len(calls) for calls in meth_by_coord.values()), reverse=True)

    cutoff = site_depths[-1]

    if len(site_depths) > median_call_count:
        cutoff = site_depths[median_call_count]

    logger.info('median per element call count: %d' % median_call_count)
    logger.info('per site call count cutoff: %d' % cutoff)

    rows = []

    for c in sorted(meth_by_coord):
        if len(meth_by_coord[c]) >= cutoff:
            rows.extend((c, m) for m in meth_by_coord[c])

    meanplot_table = pd.DataFrame(rows, columns=['coord', 'meth'])

    sns.lineplot(x='coord', y='meth', data=meanplot_table, ci='sd', lw=2, color=colour, ax=ax)


def wgmeth(args):
    '''
    generates whole genome output in DSS or bedmethyl format

    Every chromosome listed in --fai (or only --chrom) is split into --binsize
    bins, and get_meth_calls_wg() is run on each bin in a process pool. The
    per-bin tables are concatenated, sorted by chr/pos and written to
    <bam basename>[.<chrom>].<mod>[.phase_<n>].{DSS.txt,methyl.bed}.
    With --phased, phases 0 and 1 are written as separate files.

    If --mod is not given and the database holds exactly one mod, that mod is
    used (and stored back on args.mod). Otherwise --mod must name one of the
    database's mods. Invalid choices log a warning and exit.
    '''
    mods = sorted(get_modnames(args.methdb))

    if args.mod is None:
        if len(mods) == 0:
            logger.warning('no mods found in db: %s' % args.methdb)
            sys.exit()

        if len(mods) > 1:
            logger.warning('more than one mod exists, need to pick one with -m/--mod: %s' % ','.join(mods))
            sys.exit()

        # set on args as well so worker processes (which receive args) agree
        args.mod = mods[0]
        logger.info('using the one available mod: %s' % args.mod)

    elif args.mod not in mods:
        logger.warning('mod %s not in known mods for db: %s' % (args.mod, ','.join(mods)))
        sys.exit()

    phases = (0, 1) if args.phased else (0,)
    bin_size = int(args.binsize)

    seg_meth_table_store = dd(list)

    # the pool is torn down once every bin's result has been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.fai) as fai:
            for line in fai:
                fields = line.strip().split()

                if len(fields) < 2:
                    continue  # skip blank / malformed lines

                chrom, chrlen = fields[0], int(fields[1])

                if args.chrom and args.chrom != chrom:
                    continue

                for seg_start in range(0, chrlen, bin_size):
                    seg_end = min(seg_start + bin_size, chrlen)

                    res = pool.apply_async(get_meth_calls_wg, [args, args.bam, args.methdb, chrom, seg_start, seg_end, args.phased, args.mod])
                    results.append(res)

        for res in results:
            seg_meth_table = res.get()

            if seg_meth_table is None:
                continue

            for phase in phases:
                seg_meth_table_store[phase].append(pd.DataFrame.from_dict(seg_meth_table[phase]).T)

    outfn_base = '.'.join(os.path.basename(args.bam).split('.')[:-1])

    if args.chrom:
        outfn_base += '.%s' % args.chrom

    outfn_base += '.%s' % str(args.mod)

    for phase in phases:
        outfn = outfn_base

        if args.phased:
            outfn += '.phase_%d' % phase

        outfn += '.DSS.txt' if args.dss else '.methyl.bed'

        if not seg_meth_table_store[phase]:
            logger.warning('no methylation calls found, %s not written' % outfn)
            continue

        meth_table = pd.concat(seg_meth_table_store[phase])
        meth_table['pos'] = meth_table.index
        meth_table = meth_table.sort_values(['chr', 'pos'])

        logger.info('writing %s' % outfn)

        _write_wgmeth_table(meth_table, outfn, args.dss, args.methdb)


def _write_wgmeth_table(meth_table, outfn, dss, methdb):
    '''
    write one chr/pos-sorted wgmeth table (columns chr, pos, N, X) to outfn,
    either as DSS text or as 11-column bedMethyl named after the methdb file
    '''
    if dss:
        # DSS: chr, pos (1-based), N, X with header
        meth_table.to_csv(outfn, columns=['chr', 'pos', 'N', 'X'], index=False, sep='\t')
        return

    meth_table['score']  = meth_table['N'].where(meth_table['N'] <= 1000, 1000)  # cap score at 1000
    meth_table['start']  = meth_table['pos'] - 1  # 0-based
    meth_table['end']    = meth_table['pos']
    meth_table['pct']    = meth_table['X'] / meth_table['N'] * 100  # bedMethyl percent (0-100)
    meth_table['name']   = '.'.join(methdb.split('.')[:-1])
    meth_table['strand'] = '.'
    meth_table['colour'] = '0,0,0'

    # BED9 (start/end repeated as thickStart/thickEnd) + coverage (N) + percent
    meth_table.to_csv(outfn, columns=['chr', 'start', 'end', 'name', 'score', 'strand', 'start', 'end', 'colour', 'N', 'pct'], header=False, index=False, sep='\t')

def main(args):
    '''
    log the full command line, then hand off to the subcommand handler that
    argparse attached to args.func via set_defaults()
    '''
    command_line = ' '.join(sys.argv)
    logger.info('starting methylartist with command: %s' % command_line)

    handler = args.func
    handler(args)


if __name__ == '__main__':
    # command-line interface: one subcommand per tool, each bound to its
    # handler function through set_defaults(func=...) and dispatched by main()
    parser = argparse.ArgumentParser(description='methylartist: tools for exploring nanopore modified base data')
    subparsers = parser.add_subparsers(title="tool", dest="tool")
    subparsers.required = True

    parser_nanopolish    = subparsers.add_parser('db-nanopolish')
    parser_megalodon     = subparsers.add_parser('db-megalodon')
    parser_guppy         = subparsers.add_parser('db-guppy')
    parser_segmeth       = subparsers.add_parser('segmeth')
    parser_segplot       = subparsers.add_parser('segplot')
    parser_locus         = subparsers.add_parser('locus')
    parser_region        = subparsers.add_parser('region')
    parser_composite     = subparsers.add_parser('composite')
    parser_wgmeth        = subparsers.add_parser('wgmeth')
    parser_adjustcutoffs = subparsers.add_parser('adjustcutoffs')
    parser_scoredist     = subparsers.add_parser('scoredist')

    parser_segmeth.set_defaults(func=segmeth)
    parser_segplot.set_defaults(func=segplot)
    parser_locus.set_defaults(func=locus)
    parser_region.set_defaults(func=region)
    parser_composite.set_defaults(func=composite)
    parser_wgmeth.set_defaults(func=wgmeth)
    parser_nanopolish.set_defaults(func=db_nanopolish)
    parser_megalodon.set_defaults(func=db_megalodon)
    parser_guppy.set_defaults(func=db_guppy)
    parser_adjustcutoffs.set_defaults(func=adjustcutoffs)
    parser_scoredist.set_defaults(func=scoredist)

    # options for methylation segments
    parser_segmeth.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_segmeth.add_argument('-i', '--intervals', required=True, help='.bed file')
    parser_segmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_segmeth.add_argument('--max_read_density', default=None, help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')
    parser_segmeth.add_argument('--excl_ambig', action='store_true', default=False)

    # options for methylation strip / violin plots
    parser_segplot.add_argument('-s', '--segmeth', required=True, help='output from segmeth.py')
    parser_segplot.add_argument('-m', '--samples', default=None, help='samples, comma delimited')
    parser_segplot.add_argument('-d', '--mods', default=None, help='mods, comma delimited')
    parser_segplot.add_argument('-c', '--categories', default=None, help='categories, comma delimited, need to match seg_name column from input')
    parser_segplot.add_argument('-v', '--violin', default=False, action='store_true')
    parser_segplot.add_argument('-n', '--mincalls', default=10, help='minimum number of calls to include site (methylated + unmethylated) (default=10)')
    parser_segplot.add_argument('--width', default=None, help='figure width (default = automatic)')
    parser_segplot.add_argument('--height', default=6, help='figure height (default = 6)')
    parser_segplot.add_argument('--pointsize', default=1, help='point size for scatterplot (default = 1)')
    parser_segplot.add_argument('--ymin', default=-0.15, help='ymin (default = -0.15)')
    parser_segplot.add_argument('--ymax', default=1.15, help='ymax (default = 1.15)')
    parser_segplot.add_argument('--tiltlabel', default=False, action='store_true')
    parser_segplot.add_argument('--palette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_segplot.add_argument('--svg', default=False, action='store_true')

    # options for locus-specific plots
    parser_locus.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_locus.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_locus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_locus.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_locus.add_argument('-b', '--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_locus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_locus.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_locus.add_argument('-t', '--slidingwindowstep', default=1, help='step size for initial sliding window (default=1)')
    parser_locus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-separated integers. Default: 1,5,1,3,3')
    parser_locus.add_argument('--phased', default=False, action='store_true', help='split samples into phases')
    parser_locus.add_argument('--include_unphased', default=False, action='store_true', help='include an "unphased" category if called with --phased')
    parser_locus.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_locus.add_argument('--slidingwindowsize', default=2, help='size of initial sliding window for coverage check (default=2)')
    parser_locus.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_locus.add_argument('--maskcutoff', default=4, help='read count masking cutoff (default=4)')
    parser_locus.add_argument('-n', '--mincalls', default=0, help='drop modspace positions if call count (meth+unmeth) < --mincalls (default=0)')
    parser_locus.add_argument('--max_read_density', default=None, help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')
    parser_locus.add_argument('--modspace', default=None, help='spacing between links in top panel (default=auto)')
    parser_locus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_locus.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_locus.add_argument('--methcall_ymax', default=None)
    parser_locus.add_argument('--methcall_xmin', default=None)
    parser_locus.add_argument('--methcall_xmax', default=None)
    parser_locus.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_locus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_locus.add_argument('--genepalette', default="viridis", help='colour palette name for genes (default = "viridis")')
    parser_locus.add_argument('--highlight_panels', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_locus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_locus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_locus.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_locus.add_argument('--height', default=8, help='image height (inches, default=8)')
    parser_locus.add_argument('--svg', action='store_true')

    # options for region plots
    parser_region.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_region.add_argument('-d', '--data', required=True, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_region.add_argument('-n', '--norm_motif', required=True, help='normalise window sizes to motif occurrence')
    parser_region.add_argument('-r', '--ref', required=True, help='ref genome fasta, required if normalising windows with -n/--norm_motif')
    parser_region.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_region.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_region.add_argument('-b', '--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_region.add_argument('-w', '--windows', default=None, help='set window count, default=auto')
    parser_region.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_region.add_argument('-m', '--mods', default=None, help='mods to consider (comma-delimited, default = all available)')
    parser_region.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_region.add_argument('--maxuncovered', default=50.0, help='maximum percentage of uncovered windows tolerated (default = 50.0)')
    parser_region.add_argument('--modspace', default=None, help='increase to increase spacing between links in top panel (default=auto)')
    parser_region.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_region.add_argument('--min_window_calls', default=1, help='minimum reads per window to include in plot (default = 1)')
    parser_region.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_region.add_argument('--max_read_density', default=None, help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')
    parser_region.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_region.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_region.add_argument('--genepalette', default="viridis", help='colour palette name for genes (default = "viridis")')
    parser_region.add_argument('--highlight_panels', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_region.add_argument('--panelratios',  default=None, help='Alter panel ratios: needs to be 4 comma-separated integers. Default: 1,5,3,3')
    parser_region.add_argument('--skip_align_plot', default=False, action='store_true', help='blank alignment plot, useful if unneeded or for runtime.')
    parser_region.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_region.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_region.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_region.add_argument('--height', default=8, help='image height (inches, default=8)')
    parser_region.add_argument('--svg', action='store_true', default=False)

    # options for composite plots
    parser_composite.add_argument('-s', '--segdata', required=True, help='segmeth output')
    parser_composite.add_argument('-b', '--bam', required=True)
    parser_composite.add_argument('-m', '--meth', required=True, help='methylation database')
    parser_composite.add_argument('-f', '--fams', required=True, help='families, comma delimited, corresponds to values in seg_name column in segmeth output')
    parser_composite.add_argument('-r', '--ref', required=True, help='ref genome fasta')
    parser_composite.add_argument('-t', '--teref', required=True, help='TE ref fasta')
    parser_composite.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_composite.add_argument('-c', '--colour', default='#ff4500', help='colour (default: #ff4500)')
    parser_composite.add_argument('-a', '--alpha', default=0.3, help='alpha (default: 0.3)')
    parser_composite.add_argument('-w', '--linewidth', default=1, help='line width (default: 1)')
    parser_composite.add_argument('-l', '--lenfrac', default=0.95, help='fraction of TE length that must align (default 0.95)')
    parser_composite.add_argument('--plotmean', default=False, action='store_true', help='plot mean of composite elements at top of plot')
    parser_composite.add_argument('--mod', default=None, help='modification to plot (mod codes will be listed, default: infer from sample name)')
    parser_composite.add_argument('--motif', default='CG', help='modified motif to highlight (default = CG)')
    parser_composite.add_argument('--blocks', default=None, help='blocks to highlight (txt file with start, end, name, hex colour)')
    parser_composite.add_argument('--start', default=None, help='start plotting at this base (default None)')
    parser_composite.add_argument('--end', default=None, help='end plotting at this base (default None)')
    parser_composite.add_argument('--sample', required=True, help='sample name; must match a <sample>_meth_calls column prefix in --segdata')
    parser_composite.add_argument('--mincalls', default=100, help='minimum call count to include elt (default = 100)')
    parser_composite.add_argument('--maxsegs', default=300, help='maximum elements, if > max random.sample() (default = 300)')
    parser_composite.add_argument('--outelts', default=100, help='maximum output elements, if > max random.sample() (default = 100)')
    parser_composite.add_argument('--slidingwindowsize', default=10, help='size of sliding window for meth frac (default 10)')
    parser_composite.add_argument('--slidingwindowstep', default=1, help='step size for meth frac (default 1)')
    parser_composite.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default 8)')
    parser_composite.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_composite.add_argument('--max_read_density', default=None, help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')
    parser_composite.add_argument('--globalign', action='store_true', default=False, help='experimental')
    parser_composite.add_argument('--excl_ambig', action='store_true', default=False)
    parser_composite.add_argument('--svg', action='store_true', default=False)

    # options for whole genome output
    parser_wgmeth.add_argument('-b', '--bam', required=True, help='bam used for methylation calling')
    parser_wgmeth.add_argument('-d', '--methdb', required=True, help='methylation database')
    parser_wgmeth.add_argument('-s', '--binsize', default=1000000, help='bin size for parallelisation, default = 1000000')
    parser_wgmeth.add_argument('-f', '--fai', required=True, help='fasta index (.fai)')
    parser_wgmeth.add_argument('-m', '--mod', default=None, help='output for specific mod (names vary, see output for hints)')
    parser_wgmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_wgmeth.add_argument('-c', '--chrom', default=None, help='limit analysis to one chromosome')
    parser_wgmeth.add_argument('--max_read_density', default=None, help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')
    parser_wgmeth.add_argument('--dss', default=False, action='store_true', help='output in DSS format (default = bedMethyl)')
    parser_wgmeth.add_argument('--phased', action='store_true', default=False, help='split output into phases (currently just 0,1)')

    # options for megalodon db
    parser_megalodon.add_argument('-m', '--methdata', required=True, help='megalodon methylation output, can be comma-delimited')
    parser_megalodon.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_megalodon.add_argument('-p', '--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_megalodon.add_argument('-a', '--append', default=False, action='store_true', help='append to database')

    # options for guppy db
    parser_guppy.add_argument('-s', '--samplename', required=True, help='name for sample')
    parser_guppy.add_argument('-f', '--fast5', required=True, help='fast5 with called bases')
    parser_guppy.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_guppy.add_argument('-m', '--motif', required=True, help='motif e.g. G[A]TC or [C]G')
    parser_guppy.add_argument('-n', '--modname', required=True, help='mod name in guppy fast5 modified base alphabet (5mC, 6mA, etc)')
    parser_guppy.add_argument('-b', '--bam', required=True, help='.bam file containing alignments of reads from fast5')
    parser_guppy.add_argument('-r', '--ref', required=True, help='reference genome fasta (samtools faidx indexed)')
    parser_guppy.add_argument('--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_guppy.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_guppy.add_argument('--include_unmatched', action='store_true', default=False, help='include sites where read base does not match genome base')

    # options for nanopolish db
    parser_nanopolish.add_argument('-m', '--methdata', required=True, help='whole genome nanopolish methylation output, can be comma-delimited')
    parser_nanopolish.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_nanopolish.add_argument('-t', '--thresh', default=2.5, help='llr threshold (default = 2.5; if using --scalegroup the suggested setting is 2.0)')
    parser_nanopolish.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_nanopolish.add_argument('-s', '--scalegroup', default=False, action='store_true', help='scale threshold by number of CpGs in a group')
    parser_nanopolish.add_argument('-n', '--modname', default='CpG', help='modification type (tag if combining multiple mods, default = "CpG")')

    # options for adjusting methylation / unmethylation cutoffs in methylartist db
    parser_adjustcutoffs.add_argument('-d', '--db', required=True, help='methylartist database')
    parser_adjustcutoffs.add_argument('--mod', required=True, help='modification to adjust (will list for user if incorrect)')
    parser_adjustcutoffs.add_argument('-m', '--methylated', required=True, help='mark as methylated above cutoff value')
    parser_adjustcutoffs.add_argument('-u', '--unmethylated', required=True, help='mark as unmethylated below cutoff value')

    # options for score distribution exploration function
    parser_scoredist.add_argument('-d', '--db', required=True, help='methylartist database(s), can be comma-delimited')
    parser_scoredist.add_argument('-n', '--n', default=1000000, help='sample size (default = 1000000)')
    parser_scoredist.add_argument('-m', '--mod', required=True, help='modification to plot (will list for user if incorrect)')
    parser_scoredist.add_argument('--xmin', default=None)
    parser_scoredist.add_argument('--xmax', default=None)
    parser_scoredist.add_argument('--svg', action='store_true', default=False)


    args = parser.parse_args()
    main(args)
