#!/usr/bin/env python

import os
import sys
import argparse
import gzip
import csv
import logging
import subprocess
import multiprocessing as mp
from tkinter import FALSE
import matplotlib
import random
import itertools
import sqlite3
import warnings
from matplotlib.colors import NoNorm

# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')

# Illustrator compatibility
new_rc_params = {'text.usetex': False, "svg.fonttype": 'none'}
matplotlib.rcParams.update(new_rc_params)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch
import seaborn as sns

import numpy as np
import pandas as pd
import pysam
import scipy.stats as ss

from pandas.api.types import CategoricalDtype

skbio_installed = True
conflicting_call_warning = False

try:
    import skbio.alignment as skalign
    import skbio.sequence as skseq

except ModuleNotFoundError:
    skbio_installed = False

if skbio_installed:
    warnings.filterwarnings('ignore', module='skbio')

from uuid import uuid4
from collections import defaultdict as dd
from collections import Counter
from collections import namedtuple
from itertools import product
from operator import itemgetter
from copy import deepcopy

ont_fast5_api_installed = True

try:
    from ont_fast5_api.fast5_interface import get_fast5_file

except ModuleNotFoundError:
    ont_fast5_api_installed = False

from bx.intervals.intersection import Intersecter, Interval

FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Read:
    """Modification calls for a single read, keyed by site position.

    Each site stores its statistic (``llrs``), its discrete call
    (``meth_calls``: -1, 0 or 1) and the name of the modification
    (``mod_names``). ``call_count`` counts the non-zero calls passed to
    ``add_mod``.
    """

    def __init__(self, read_name, cpg_loc, stat, methcall, modname, phase=None):
        self.read_name = read_name
        self.phase = phase

        # per-site records, all keyed by site position
        self.llrs = {}
        self.meth_calls = {}
        self.mod_names = {}

        self.call_count = 0

        # used for locus per-read plot
        self.ypos = None
        self.starts = []
        self.ends = []

        # a Read is always created together with its first call
        self.add_mod(cpg_loc, stat, methcall, modname)

    def add_mod(self, cpg_loc, stat, methcall, modname):
        """Record (or overwrite) the call at ``cpg_loc``.

        A warning is logged the first time a site is seen twice; the
        module-level flag ``conflicting_call_warning`` suppresses repeats.
        """
        global conflicting_call_warning

        assert methcall in (-1,0,1)

        already_called = cpg_loc in self.llrs

        if already_called:
            if not conflicting_call_warning:
                logger.warning('''warning: multiple modification calls for the same site + mod, read: %s  cpg_loc: %d  mod: %s
                        consider excluding non-primary alignments from your analysis, or if using a .bam with mod tags try --primary_only
                        this warning is only shown once per thread'''% (self.read_name, cpg_loc, modname))

            conflicting_call_warning = True

        # the latest call for a site always wins
        self.llrs[cpg_loc] = stat
        self.meth_calls[cpg_loc] = methcall
        self.mod_names[cpg_loc] = modname

        if methcall != 0:
            self.call_count += 1

    def overlap(self, other):
        """Return True when the plotted extents of the two reads overlap."""
        earliest_end = min(max(self.ends), max(other.ends))
        latest_start = max(min(self.starts), min(other.starts))

        return earliest_end - latest_start > 0

    def mod_count(self):
        """Number of sites whose call is 1."""
        return sum(1 for call in self.meth_calls.values() if call == 1)

    def site_count(self):
        """Number of distinct sites recorded for this read."""
        return len(self.meth_calls)

    def mod_frac(self):
        """Fraction of recorded sites called 1 (0 when there are no sites)."""
        total_sites = self.site_count()

        if total_sites == 0:
            return 0

        return self.mod_count()/total_sites


class Gene:
    """Gene model: identifier, name, strand, transcript/CDS extent and exons.

    Exons are stored as two-element ``[start, end]`` lists, kept sorted by
    start coordinate.
    """

    def __init__(self, ensg, name, strand):
        self.ensg = ensg
        self.name = name
        self.strand = strand
        self.tx_start = None
        self.tx_end = None
        self.cds_start = None
        self.cds_end = None
        self.exons = []

    def add_exon(self, block):
        """Add an exon block, normalising it to ``[start, end]`` order."""
        assert len(block) == 2

        # swap only when the block arrives reversed (end before start)
        if block[0] > block[1]:
            block[0], block[1] = block[1], block[0]

        self.exons.append(block)
        self.exons = sorted(self.exons, key=itemgetter(0))

    def add_tx(self, block):
        """Extend the transcript extent so that it covers ``block``."""
        assert len(block) == 2
        assert block[0] < block[1]
        if self.tx_start is None or self.tx_start > block[0]:
            self.tx_start = block[0]

        if self.tx_end is None or self.tx_end < block[1]:
            self.tx_end = block[1]

    def add_cds(self, block):
        """Extend the CDS extent so that it covers ``block``.

        Blocks whose start is greater than their end are logged and ignored.
        """
        assert len(block) == 2
        if block[0] > block[1]:
            logger.warning('CDS block start > end in gene %s' % self.ensg)
            return None

        if self.cds_start is None or self.cds_start > block[0]:
            self.cds_start = block[0]

        if self.cds_end is None or self.cds_end < block[1]:
            self.cds_end = block[1]

    def has_tx(self):
        """True once both transcript bounds have been set."""
        return None not in (self.tx_start, self.tx_end)

    def has_cds(self):
        """True once both CDS bounds have been set."""
        return None not in (self.cds_start, self.cds_end)

    def merge_exons(self):
        """Collapse overlapping exon blocks into single blocks, in place."""
        new_exons = []
        if len(self.exons) == 0:
            return

        last_block = self.exons[0]

        for block in self.exons[1:]:
            if min(block[1], last_block[1]) - max(block[0], last_block[0]) > 0: # overlap
                # keep growing the current merged block
                last_block = [min(block[0], last_block[0]), max(block[1], last_block[1])]

            else:
                # no overlap: the current block is finished, start a new one
                new_exons.append(last_block)
                last_block = block

        new_exons.append(last_block)

        self.exons = new_exons


# Complement lookup for single upper-case bases (A, T, C, G, plus N mapped to itself).
RC = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}

class Motif:
    """Sequence motif with one bracketed key base, e.g. ``'[C]G'`` or ``'A[C]G'``.

    Attributes:
        motif:       the motif string as given
        left_bases:  bases before the ``[``
        key_base:    text between ``[`` and ``]``
        right_bases: bases after the ``]``
        full_seq:    left_bases + key_base + right_bases
    """

    def __init__(self, motif):

        self.motif = motif

        assert len(motif.split('[')) == 2, 'bad motif syntax: %s' % motif
        assert len(motif.split(']')) == 2, 'bad motif syntax: %s' % motif

        self.left_bases  = motif.split('[')[0]
        self.right_bases = motif.split(']')[1]
        self.key_base    = motif.split('[')[1].split(']')[0]
        self.full_seq    = self.left_bases + self.key_base + self.right_bases

        # older spelling of left_bases, kept in case external code reads it
        self.left_base = self.left_bases

        assert self.key_base in 'ATCG', 'bad motif syntax: %s' % motif


    def match_seq(self, seq):
        ''' returns 0-based coords of key base for matches in seq '''

        sites = dd(list)

        seq = seq.upper()
        width = len(self.full_seq)

        # +1 so the window that ends on the last base of seq is also tested
        for i in range(len(seq) - width + 1):
            if seq[i:i+width] == self.full_seq:
                sites[i + len(self.left_bases)].append(self.motif)

        return sites
        

class Mod:
    """One motif-matched base of a read together with its modification scores.

    The read-level fields are set at construction; the alignment-level
    fields (chrom, genome_pos, genome_base, aln_strand) start as None and
    are expected to be filled in before the object is converted to a string.
    """

    def __init__(self, read_id, read_pos, read_base, sites, mods):
        # read-level information
        self.read_id = read_id
        self.read_pos = read_pos
        self.read_base = read_base
        self.sites = sites
        self.mods = mods

        # alignment-level information, assigned later
        self.chrom = None
        self.genome_pos = None
        self.genome_base = None
        self.aln_strand = None

    def __str__(self):
        # one tab-separated row: alignment columns, read columns,
        # comma-joined motif list, then one column per score
        motif_column = ','.join(self.sites)
        score_columns = '\t'.join(str(score) for score in self.mods)

        columns = [
            self.read_id,
            self.chrom,
            str(self.genome_pos),
            self.genome_base,
            self.aln_strand,
            str(self.read_pos),
            self.read_base,
            motif_column,
            score_columns,
        ]

        return '\t'.join(columns)


def base_field(mod_metadata):
    """Return one column name per symbol of the basecaller alphabet.

    The alphabet is read from ``output_alphabet`` when present, otherwise
    from ``modified_base_alphabet``. A, C, T and G are kept as they are;
    every other symbol is replaced, in order, by the next entry of the
    whitespace-separated ``modified_base_long_names`` field.
    """
    long_names = mod_metadata['modified_base_long_names'].split()

    # first matching key wins; stays None when neither is present
    alphabet_key = None
    for candidate in ('output_alphabet', 'modified_base_alphabet'):
        if candidate in mod_metadata:
            alphabet_key = candidate
            break

    columns = []
    next_long_name = 0

    for symbol in mod_metadata[alphabet_key]:
        if symbol in 'ACTG':
            columns.append(symbol)
            continue

        columns.append(long_names[next_long_name])
        next_long_name += 1

    return columns


def pandas_df(out, mod_base, can_base):
    """Turn tab-separated text lines into a DataFrame with log-probability columns.

    ``out`` is an iterable of lines whose first element is the header.
    ``mod_base`` and ``can_base`` name the columns holding the modified and
    canonical base scores. Three columns are added: ``can_log_prob`` and
    ``mod_log_prob`` (log((score+1)/256)) and ``modstat`` (their difference,
    modified minus canonical).
    """
    rows = dd(dict)
    columns = []

    for line_no, raw_line in enumerate(out):
        if line_no == 0:
            # header is split on any whitespace, data lines on tabs
            columns = raw_line.strip().split()
        else:
            for col_no, field in enumerate(raw_line.strip().split('\t')):
                rows[line_no][columns[col_no]] = field

    # rows are keyed by line number, so transpose to one row per input line
    frame = pd.DataFrame.from_dict(rows).T
    frame = pd.DataFrame(frame.to_dict())

    can_scores = pd.to_numeric(frame[can_base])
    frame['can_log_prob'] = np.log((can_scores+1)/256)

    mod_scores = pd.to_numeric(frame[mod_base])
    frame['mod_log_prob'] = np.log((mod_scores+1)/256)

    frame['modstat'] = frame['mod_log_prob'] - frame['can_log_prob']

    return frame


def guppy_f5_fetch(fast5, bam_db, args):
    """Extract per-site modified-base scores from one fast5 file into a TSV.

    For every read in ``args.fast5/<fast5>`` the ModBaseProbs table is
    scanned at positions matching ``args.motif``. The reads are then looked
    up by name in the sqlite database ``bam_db`` (table ``bam``, column
    ``sam``), the matched read positions are projected onto the reference
    through the alignment, and the result is written next to the fast5 file
    as ``<name>.gupmod.tsv``.

    Args:
        fast5:  file name of the fast5, relative to ``args.fast5``
        bam_db: path to the sqlite database of SAM records
        args:   namespace providing ``bam``, ``ref``, ``fast5``, ``motif``,
                ``modname`` and ``include_unmatched``

    Returns:
        The output file name, or None when the fast5 has no usable
        basecalling information or ``args.modname`` is not available.
    """
    bam = pysam.AlignmentFile(args.bam)
    ref = pysam.Fastafile(args.ref)

    bases = []
    modbases = dd(list)

    # stays None until the metadata of the first read has been parsed
    header = None

    m = Motif(args.motif)

    out = []

    with get_fast5_file(args.fast5 + '/' + fast5, mode="r") as f5:
        for read_id in f5.get_read_ids():
            read = f5.get_read(read_id)
            latest_basecall = read.get_latest_analysis('Basecall_1D')

            mod_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/ModBaseProbs')
            called_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/Fastq')

            mod_table_path = '{}/BaseCalled_template/ModBaseProbs'.format(latest_basecall)
            called_table_path = '{}/BaseCalled_template/Fastq'.format(latest_basecall)

            mod_metadata = read.get_analysis_attributes(mod_table_path)
            called_metadata = read.get_analysis_attributes(called_table_path)

            if None in (mod_metadata, called_metadata):
                logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
                return None

            if len(bases) == 0:
                # the column layout is taken from the first read only
                bases = base_field(mod_metadata)

                mod_names = [b for b in bases if b not in ('ATCG')]

                if args.modname not in mod_names:
                    logger.error('mod "%s" not found in metadata. Available mod names: %s' % (args.modname, ','.join(mod_names)))
                    return None

                header = "readname\tchrom\tpos\tgenome_base\tstrand\tread_pos\tread_base\tmotif\t" + '\t'.join(bases) + "\n"

            # second line of the FASTQ record is the basecalled sequence
            seq = called_base_table.split('\n')[1]

            assert len(seq) == len(mod_base_table)

            sites = m.match_seq(seq)

            for i, mods in enumerate(mod_base_table):
                if i in sites:
                    modbases[read_id].append(Mod(read_id, i, seq[i], sites[i], mods))

    logger.info('processed %d reads from %s' % (len(modbases), fast5))

    if header is None:
        # no read was seen, so there is no column layout to write
        logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
        return None

    logger.info('parsing alignments from %s' % bam_db)

    out.append(header)

    conn = sqlite3.connect(bam_db)

    try:
        c = conn.cursor()

        for readname in modbases:
            # bound parameter instead of string formatting
            for row in c.execute("SELECT sam FROM bam WHERE readname=?", (readname,)):
                read = pysam.AlignedSegment.fromstring(row[0], bam.header)

                # query position -> reference position, aligned columns only
                pos_lookup = {}
                for ap in read.get_aligned_pairs():
                    if None not in ap:
                        pos_lookup[ap[0]] = ap[1]

                for modbase in modbases[read.query_name]:
                    modbase.aln_strand = '+'
                    aligned_read_pos = modbase.read_pos

                    if read.is_reverse:
                        modbase.aln_strand = '-'
                        aligned_read_pos = (len(read.seq) - modbase.read_pos)-1 # SAM format stores read on + strand

                    if aligned_read_pos not in pos_lookup:
                        continue

                    modbase.chrom = read.reference_name
                    modbase.genome_pos = pos_lookup[aligned_read_pos]
                    modbase.genome_base = ref.fetch(modbase.chrom, modbase.genome_pos, modbase.genome_pos+1)

                    modbase.genome_base = modbase.genome_base.upper()

                    if args.include_unmatched:
                        out.append(str(modbase) + '\n')

                    else:
                        if modbase.aln_strand == '+' and modbase.genome_base == modbase.read_base:
                            out.append(str(modbase) + '\n')

                        # .get(): reference bases outside RC (e.g. IUPAC codes) never match
                        if modbase.aln_strand == '-' and RC.get(modbase.genome_base) == modbase.read_base:
                            out.append(str(modbase) + '\n')

    finally:
        conn.close()

    data = pandas_df(out, args.modname, m.key_base)
    out_fn = args.fast5 + '/' + '.'.join(fast5.split('.')[:-1]) + '.gupmod.tsv'
    logger.info('writing output from %s to %s' % (fast5, out_fn))
    data.to_csv(out_fn, sep='\t', index=False)

    return out_fn


def exclude_ambiguous_reads(fn, chrom, start, end, min_mapq=10, spanning_only=False):
    """Return names of reads in the region that are anchored outside it.

    A read is kept when its first aligned reference position is before
    ``start`` or its last one is after ``end``, and its mapping quality is
    at least ``min_mapq``. With ``spanning_only`` the read must additionally
    start before ``start`` and end after ``end``.

    Reads reporting no aligned reference positions are skipped.
    """
    reads = []

    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            if not p:
                # nothing aligned, so the read cannot be anchored anywhere
                continue

            if not (p[0] < start or p[-1] > end):
                continue

            if read.mapq < min_mapq:
                continue

            if spanning_only and not (read.reference_start < start and read.reference_end > end):
                continue

            reads.append(read.query_name)

    return reads


def get_ambiguous_reads(fn, chrom, start, end, min_mapq=10, w=50):
    """Return names of reads in the region considered ambiguous.

    A read is reported when its mapping quality is below ``min_mapq``, or
    when its aligned positions lie strictly inside the region padded by
    ``w`` bases on each side (first position > start-w and last position
    < end+w).

    Reads with adequate mapping quality but no aligned reference positions
    are not reported.
    """
    reads = []

    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                reads.append(read.query_name)
                continue

            p = read.get_reference_positions()

            # guard against an empty position list before indexing it
            if p and p[0] > start-w and p[-1] < end+w:
                reads.append(read.query_name)

    return reads


def get_reads(fn, chrom, start, end, min_mapq=10, spanning_only=False, primary_only=False):
    """Return the names of reads overlapping chrom:start-end.

    Filters, in order: secondary/supplementary alignments are dropped when
    ``primary_only`` is set; reads below ``min_mapq`` are dropped; with
    ``spanning_only`` a read must start before ``start`` and end after
    ``end``. Names are returned once per passing alignment.
    """
    names = []

    alignments = pysam.AlignmentFile(fn)

    for aln in alignments.fetch(chrom, start, end):
        if primary_only and (aln.is_secondary or aln.is_supplementary):
            continue

        if not aln.mapq >= min_mapq:
            continue

        if spanning_only and not (aln.reference_start < start and aln.reference_end > end):
            continue

        names.append(aln.query_name)

    return names


def get_phased_reads(fn, chrom, start, end, min_mapq=10, tag_untagged=False, ignore_tags=False, HP_only=False):
    """Map read names in the region to a phase label built from HP/PS tags.

    Args:
        fn:           path to the alignment file
        chrom, start, end: region to fetch
        min_mapq:     reads below this mapping quality are left out
        tag_untagged: label reads without an HP tag as 'unphased'
        ignore_tags:  do not read tags at all; every read is 'unphased'
        HP_only:      use only the HP value, without the ':<PS>' suffix

    Returns:
        dict of read name -> label. The label is '<HP>' or '<HP>:<PS>' for
        reads with an HP tag; otherwise 'unphased' (when ``tag_untagged`` or
        ``ignore_tags`` is set) or None.
    """
    reads = {}

    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            phase = None

            if tag_untagged or ignore_tags:
                phase = 'unphased'

            HP = None
            PS = None

            if not ignore_tags:
                for tag in read.get_tags():
                    if tag[0] == 'HP':
                        HP = tag[1]
                    if tag[0] == 'PS':
                        PS = tag[1]

            if HP is not None:
                phase = str(HP)

                # PS only qualifies an existing HP value; a PS tag on its own
                # must not turn the label into 'None' or raise on None + str
                if PS is not None and not HP_only:
                    phase = phase + ':' + str(PS)

            reads[read.query_name] = phase

    return reads


def single_seq_fa(fn):
    """Read a FASTA file holding a single entry and return its sequence.

    Sequence lines are stripped and concatenated. A header line encountered
    after sequence has already been collected fails the assertion.
    """
    chunks = []

    with open(fn, 'r') as handle:
        for record_line in handle:
            if not record_line.startswith('>'):
                chunks.append(record_line.strip())
                continue

            # a header is only acceptable before any sequence was read
            assert not any(chunks), 'input fa must have only one entry'

    return ''.join(chunks)


def rc(dna):
    ''' reverse complement; characters without a known complement are kept '''
    forward = 'acgtrymkbdhvACGTRYMKBDHV'
    reverse = 'tgcayrkmvhdbTGCAYRKMVHDB'
    pairing = dict(zip(forward, reverse))

    # walk the sequence back to front, swapping each base for its partner
    return ''.join(pairing.get(base, base) for base in reversed(dna))


def iupac(motif):
    """Expand a motif containing IUPAC ambiguity codes into concrete sequences.

    The motif is upper-cased first; 'U' is treated as 'T'. The expansions
    are returned as a list of strings in which the first position varies
    fastest (e.g. 'RY' -> ['AC', 'GC', 'AT', 'GT']). An empty motif gives an
    empty list.

    Exits via ``sys.exit`` with a message when a character is not an IUPAC
    base.
    """
    motif = motif.upper()

    codes = {
        'A':['A'],
        'C':['C'],
        'G':['G'],
        'T':['T'],
        'U':['T'],
        'R':['A','G'],
        'Y':['C','T'],
        'S':['G','C'],
        'W':['A','T'],
        'K':['G','T'],
        'M':['A','C'],
        'B':['C','G','T'],
        'D':['A','G','T'],
        'H':['A','C','T'],
        'V':['A','C','G'],
        'N':['A','C','G','T']
    }

    motifs = []

    for ib in motif:
        if ib not in codes:
            # report the offending character itself
            sys.exit('base %s not an IUPAC base, please modify --motif' % ib)

        if len(motifs) == 0:
            motifs = list(codes[ib])
        else:
            # extend every partial motif with every option for this position
            motifs = [m + bp for bp in codes[ib] for m in motifs]

    return motifs


def get_modnames(meth_db):
    """Return the distinct values of ``mod`` in the ``modnames`` table of ``meth_db``.

    Exits via ``sys.exit`` with a message when the database file does not
    exist.
    """
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)

    try:
        return [row[0] for row in conn.execute("SELECT DISTINCT mod FROM modnames")]

    finally:
        # release the database handle whether or not the query succeeded
        conn.close()


def get_cutoffs(meth_db, mod):
    """Return the ``(upper, lower)`` cutoffs stored for modification ``mod``.

    Reads the first matching row of the ``cutoffs`` table in ``meth_db``;
    returns None when no row matches. Exits via ``sys.exit`` with a message
    when the database file does not exist.
    """
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)

    try:
        # bound parameter: mod names containing quotes no longer break the query
        return conn.execute("SELECT upper,lower FROM cutoffs WHERE modname=?", (mod,)).fetchone()

    finally:
        conn.close()


def densecall_filter(reads, max_density=.7):
    ''' reads is a dict of Read objects with modified base calls;
        returns a new dict holding only the reads whose mod_frac() is at
        most max_density '''
    kept = {
        read_name: read_obj
        for read_name, read_obj in reads.items()
        if read_obj.mod_frac() <= max_density
    }

    # everything not kept was filtered
    dropped = len(reads) - len(kept)

    logger.debug('filtered %d reads via --max_read_density %f' % (dropped, max_density))

    return kept


def is_bam(fn):
    """Return True when every file in the comma-separated list ``fn`` can be
    opened with ``pysam.AlignmentFile``, False as soon as one cannot."""
    for f in fn.split(','):
        try:
            # open only to probe the format; close the handle straight away
            with pysam.AlignmentFile(f):
                pass

        except Exception:
            # unlike a bare except, KeyboardInterrupt/SystemExit still propagate
            return False

    return True


def get_bincover(bam_fn, chrom, start, end):
    """Count alignments overlapping chrom:start-end.

    Alignments with mapping quality below 10 and alignments flagged as
    duplicates are not counted. Returns ``[start, count]`` with ``start``
    converted to int.
    """
    start, end = int(start), int(end)

    alignments = pysam.AlignmentFile(bam_fn)

    count = sum(
        1
        for rec in alignments.fetch(chrom, start, end)
        if not rec.mapping_quality < 10 and not rec.is_duplicate
    )

    return [start, count]


def bam_bincover(bam_fn, chrom, w_starts, w_ends, procs=1, log=False):
    """Per-window read counts for a list of windows, computed in parallel.

    Args:
        bam_fn:   path to the alignment file
        chrom:    chromosome of all windows
        w_starts: window start coordinates
        w_ends:   window end coordinates (same length as ``w_starts``)
        procs:    number of worker processes
        log:      return log2(count+1) instead of raw counts

    Returns:
        numpy array of counts ordered by window start.
    """
    assert len(w_starts) == len(w_ends)

    # the context manager shuts the workers down; results are collected
    # inside it, before the pool is terminated
    with mp.Pool(processes=int(procs)) as pool:
        pending = [
            pool.apply_async(get_bincover, [bam_fn, chrom, start, end])
            for start, end in zip(w_starts, w_ends)
        ]

        segs = [res.get() for res in pending]

    segs = sorted(segs, key=itemgetter(0))
    cover = np.asarray([s[1] for s in segs])

    if log:
        cover = np.log2(cover+1)

    return cover


def bam_pileupcover(bam_fn, chrom, w_starts, w_ends, procs=1, log=False):
    """Per-position base depth from ``samtools mpileup`` within the windows.

    Runs ``samtools mpileup -a -q10`` over the span from the smallest window
    start to the largest window end and, for each reported position found
    in the window interval tree, counts the A/T/C/G characters (case
    insensitive) of the pileup column.

    Args:
        bam_fn:   path to the alignment file
        chrom:    chromosome of all windows
        w_starts: window start coordinates
        w_ends:   window end coordinates (same length as ``w_starts``)
        procs:    not used by this function
        log:      return log2(depth+1) instead of raw depth

    Returns:
        numpy array of depths ordered by position.
    """
    assert len(w_starts) == len(w_ends)

    segtree = Intersecter()
    for start, end in zip(w_starts, w_ends):
        segtree.add_interval(Interval(start-1, end+1))

    minpos = min(w_starts)
    maxpos = max(w_ends)

    region = '%s:%d-%d' % (chrom, minpos, maxpos)

    segs = []
    samtools_cmd = ['samtools', 'mpileup', '-a', '-q10', '-r', region, bam_fn]

    # Popen as a context manager closes the pipe and waits for samtools;
    # DEVNULL replaces the hand-opened os.devnull handle that was never closed
    with subprocess.Popen(samtools_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) as p:
        for pline in p.stdout:
            cols = pline.decode().strip().split()

            if len(cols) < 2:
                # blank or truncated line: no position to report
                continue

            pos = int(cols[1])

            if len(cols) < 6 or cols[-1] == '*':
                seq = ''
            else:
                # column 5 of mpileup output holds the read bases
                seq = cols[4].upper()

            if segtree.find(pos, pos):
                base_depth = len([b for b in seq if b in ('A','T','C','G')])
                segs.append([pos, base_depth])

    segs = sorted(segs, key=itemgetter(0))
    cover = np.asarray([s[1] for s in segs])

    if log:
        cover = np.log2(cover+1)

    return cover
    

def mods_methbam(bam_fn):
    """Return the modification codes listed in the MM tag of the first usable read.

    Reads are scanned in ``bam.fetch()`` order. The first read whose MM (or
    Mm) tag lists at least one modification determines the result, i.e. it
    is assumed that every tagged read lists all modifications.

    Returns:
        list of modification type strings (order not defined).

    Exits with status 1 when no read yields a modification.
    """
    logger.info('fetching mod types from %s, if this takes awhile MM/ML tags may be missing.' % bam_fn)

    mm_warned = False

    with pysam.AlignmentFile(bam_fn) as bam:
        for rec in bam.fetch():
            mm = None

            # accept both spellings of the tag
            for tag_name in ('MM', 'Mm'):
                try:
                    mm = str(rec.get_tag(tag_name)).rstrip(';')
                    break
                except KeyError:
                    continue

            if mm is None:
                if not mm_warned:
                    logger.debug('cannot find Mm tag in at least one read (example: %s), ensure this bam has MM and ML tags!' % rec.qname)
                    mm_warned = True

                continue

            mods = set()

            for mod_string in mm.split(';'):
                mod_info = mod_string.split(',')[0]

                if not mod_info:
                    # empty tag value: this read lists no modifications
                    continue

                mod_strand = '+'
                if '-' in mod_info:
                    mod_strand = '-'

                try:
                    mod_base, mod_type = mod_info.split(mod_strand)
                except ValueError:
                    logger.debug('%s: malformed mod string for read %s (%s) skipped.' % (bam_fn, rec.qname, mod_info))
                    continue

                # drop the trailing '?' / '.' flag after the mod code
                mods.add(mod_type.rstrip('?.'))

            if not mods:
                # tag present but nothing usable in it; try the next read
                continue

            mods = list(mods)

            logger.info('found mods: %s' % ','.join(mods))

            return mods # assumes all mods are represented for each read that has an Mm tag

    logger.warning('no mods found in %s, missing MM/ML tags?' % bam_fn)
    sys.exit(1)


def split_ml(mod_strings, ml):
    """Cut the flat ML score sequence into one slice per MM mod string.

    Each mod string is comma-separated; its first field describes the
    modification and every further field corresponds to one score in
    ``ml``. Returns a list of slices of ``ml`` in mod-string order, or an
    empty list when no mod string lists any position.
    """
    # number of scores owned by each mod string (descriptor field excluded)
    widths = [len(ms.split(',')[1:]) for ms in mod_strings]
    total_ms = sum(widths)

    mls = []

    if total_ms == 0:
        return mls

    assert total_ms == len(ml), 'mod bam formatting error total_ms: %d, len(ml): %d' % (total_ms, len(ml))

    offset = 0

    for width in widths:
        chunk = ml[offset:offset+width]
        assert width == len(chunk), 'mod bam formatting error'

        mls.append(chunk)
        offset += width

    return mls


def sample_bam(bam_fn, motif, ref, n):
    """Collect modification probabilities from a mod-tagged BAM until ``n`` are gathered.

    Chromosomes are visited in the order listed by the reference FASTA and
    calls are taken as ``parse_methbam`` yields them (primary alignments
    only). Returns a numpy array of the fourth field of each yielded call.
    """
    fasta = pysam.Fastafile(ref)
    wanted = int(n)

    logger.info('sampling %d mods from %s' % (wanted, bam_fn))

    collected = []

    for chrom in fasta.references:
        calls = parse_methbam(bam_fn, [], chrom, 0, fasta.get_reference_length(chrom), motifsize=len(motif), restrict_motif=motif, primary_only=True)

        for call in calls:
            collected.append(call[3])

            # stop everything as soon as the quota is reached
            if len(collected) >= wanted:
                return np.asarray(collected)

    return np.asarray(collected)


def parse_methbam(bam_fn, reads, chrom, start, end, motifsize=2, meth_thresh=0.8, can_thresh=0.8, restrict_motif=None, restrict_ref=None, primary_only=False):
    """Yield modification calls decoded from MM/ML tags for reads in a region.

    Args:
        bam_fn:         path to the alignment file
        reads:          collection of read names to keep; empty keeps all
        chrom, start, end: region to fetch
        motifsize:      motif length; reverse-strand positions are shifted
                        left by motifsize-1. Overridden by
                        len(restrict_motif) when that is given.
        meth_thresh:    call is 1 when p_mod exceeds this
        can_thresh:     call is -1 when 1-p_mod exceeds this
        restrict_motif: with restrict_ref, keep only calls whose reference
                        sequence at the site equals this motif
        restrict_ref:   path to the reference FASTA used for that check
        primary_only:   skip secondary and supplementary alignments

    Yields:
        (read name, reference name, genome position, p_mod, methstate,
        mod type) where methstate is 1, -1 or 0 (neither threshold passed).
    """
    bam = pysam.AlignmentFile(bam_fn)

    if restrict_motif:
        if len(restrict_motif) != motifsize:
            motifsize = len(restrict_motif)

    if restrict_ref is not None:
        restrict_ref = pysam.Fastafile(restrict_ref)

    mm_warned = False

    for rec in bam.fetch(chrom, start, end):
        if rec.is_unmapped:
            continue

        if rec.seq is None:
            continue

        if len(reads) > 0 and rec.qname not in reads:
            continue

        if primary_only:
            if rec.is_secondary or rec.is_supplementary:
                continue

        # query position -> reference position (None where unaligned)
        ap = dict([(k, v) for (k, v) in rec.get_aligned_pairs() if k is not None])

        mm = None
        ml = None

        try:
            mm = str(rec.get_tag('MM')).rstrip(';')

        except KeyError:
            try:
                mm = str(rec.get_tag('Mm')).rstrip(';')
            except KeyError:
                if not mm_warned:
                    logger.debug('cannot find MM tag in at least one read, ensure this bam has MM and ML tags!')
                    mm_warned = True

                # no MM tag under either spelling: nothing to decode
                continue

        try:
            ml = rec.get_tag('ML')
        except KeyError:
            try:
                ml = rec.get_tag('Ml')
            except KeyError:
                continue

        if None in (mm, ml):
            continue

        mod_strings = mm.split(';')

        mls = split_ml(mod_strings, ml)

        seq = rec.seq

        # rec.seq is stored on the reference forward strand; flip it back
        # for reverse-strand alignments before applying the MM offsets
        if rec.is_reverse:
            seq = rc(seq)

        for mod_string, scores in zip(mod_strings, mls):
            m = mod_string.split(',')
            mod_info = m[0]

            mod_relpos = list(map(int, m[1:]))

            mod_strand = '+'

            if '-' in mod_info:
                mod_strand = '-'

            try:
                mod_base, mod_type = mod_info.split(mod_strand)
                mod_type = mod_type.rstrip('?.')
            except ValueError:
                logger.debug('%s: malformed mod string for read %s (%s) skipped.' % (bam_fn, rec.qname, mod_info))
                continue

            assert len(mod_type) == 1, 'multiple modfications listed this way: %s is not yet supported, please send me an example!' % mod_info

            # indices into seq of every occurrence of the unmodified base
            base_pos = [i for i, b in enumerate(seq) if b == mod_base]

            i = -1

            for skip, score in zip(mod_relpos, scores):
                # each entry advances one base plus the number skipped
                i += 1

                if skip > 0:
                    i += skip

                genome_pos = None

                if rec.is_reverse:
                    genome_pos = ap[len(rec.seq)-base_pos[i]-1]

                    if genome_pos is None:
                        continue

                    genome_pos -= (int(motifsize)-1)

                else:
                    genome_pos = ap[base_pos[i]]

                if genome_pos is None:
                    continue

                p_mod = score/255
                p_can = 1-p_mod

                assert p_mod <= 1.0

                methstate = 0

                if p_mod > meth_thresh:
                    methstate = 1

                if p_can > can_thresh:
                    methstate = -1

                if None not in (restrict_motif, restrict_ref):
                    if genome_pos < 0:
                        continue

                    try:
                        ref_motif = restrict_ref.fetch(rec.reference_name, genome_pos, genome_pos+motifsize)
                    except ValueError:
                        logger.warning('warning, out of bounds motif at position %d in read: %s' % (genome_pos, rec.tostring()))
                        # no reference sequence to compare against; skip the
                        # call instead of using an unbound or stale ref_motif
                        continue

                    if motifsize == 1 and rec.is_reverse:
                        ref_motif = rc(ref_motif)

                    if ref_motif.upper() != restrict_motif.upper():
                        continue

                yield (rec.qname, rec.reference_name, genome_pos, p_mod, methstate, mod_type)


def get_segmeth_calls(args, bam_fn, mod_names, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam):
    """Tally modification calls per modification type over one segment.

    Reads are selected from ``bam_fn`` (optionally restricted to ``phase``),
    their calls are taken either from the MM/ML tags (``methbam``) or from
    the ``methdata`` table of each database in ``meth_dbs``, and calls whose
    position falls within [seg_start, seg_end] on ``chrom`` are counted.

    ``args`` must provide ``min_mapq``, ``primary_only`` and
    ``max_read_density``; ``ref`` and ``motif`` are read when ``methbam`` is
    set; ``excl_ambig`` and ``spanning_only`` are optional
    (``spanning_only`` is set to False on ``args`` when absent).

    Returns:
        (seg_result, seg_info) where seg_result maps each name in
        ``mod_names`` to a dict of call value -> count, and seg_info is
        (chrom, seg_start, seg_end, seg_name, seg_strand, read_count).
    """
    if 'spanning_only' not in vars(args):
        args.spanning_only=False

    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only)
    else:
        reads = get_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only, primary_only=args.primary_only)

    reads = set(reads)

    if phase:
        phased_reads_dict = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
        reads = {r for r in reads if phased_reads_dict.get(r) == phase}

    seg_reads = dd(dict)

    def _add_call(index, cg_chrom, cg_start, stat, methstate, modname):
        # record one call when it lies inside the segment
        if chrom != cg_chrom:
            return

        if cg_start < seg_start or cg_start > seg_end:
            return

        cg_seg_start = cg_start - seg_start

        if index not in seg_reads[modname]:
            seg_reads[modname][index] = Read(index, cg_seg_start, stat, methstate, modname)
        else:
            seg_reads[modname][index].add_mod(cg_seg_start, stat, methstate, modname)

    if methbam:
        for row in parse_methbam(bam_fn, reads, chrom, seg_start, seg_end, meth_thresh=0.8, can_thresh=0.8, restrict_ref=args.ref, restrict_motif=args.motif, primary_only=args.primary_only):
            _add_call(*row)

    else:
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            for index in reads:
                for conn in conns:
                    # read name passed as a bound parameter, not formatted into the SQL
                    for row in conn.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):
                        _add_call(index, *row)

        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        for modname in seg_reads:
            seg_reads[modname] = densecall_filter(seg_reads[modname], max_density=float(args.max_read_density))

    seg_result = {}
    seen_reads = {}

    for modname in mod_names:
        seg_meth_calls = dd(int)

        for name, read in seg_reads[modname].items():
            seen_reads[name] = 1

            for loc, call in read.meth_calls.items():
                if read.mod_names[loc] == modname:
                    seg_meth_calls[call] += 1

        seg_result[modname] = seg_meth_calls

    # reads contributing to any of the requested modifications
    read_count = len(seen_reads)

    return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand, read_count)


def get_meth_locus(args, bam, meth_dbs, mod, phase=None, methbam=False, HP_only=False, restrict_motif=None, restrict_ref=None):
    """Collect per-read calls of one modification over ``args.interval``.

    Used for locus plots. The interval is parsed from ``args.interval``
    ('chrom:start-end', commas allowed in the numbers). Calls come from the
    MM/ML tags when ``methbam`` is set, otherwise from the ``methdata``
    table of each database in ``meth_dbs``. Calls with methstate 0, calls
    for other modifications and calls outside the interval are dropped.

    ``args`` must provide ``interval``, ``min_mapq``, ``primary_only`` and
    ``max_read_density``; ``motifsize`` is read when ``methbam`` is set;
    ``excl_ambig``, ``unambig_highlight`` and ``highlight`` are optional.

    Returns:
        dict of read name -> Read, with site positions relative to the
        interval start.
    """
    if phase:
        logger.info('fetching reads from %s with mod %s on phase %s' % (bam, mod, phase))
    else:
        logger.info('fetching reads from %s with mod %s' % (bam, mod))

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))

    if not methbam:
        for meth_db in meth_dbs:
            if not os.path.exists(meth_db):
                sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    # get list of relevant reads (excludes reads not anchored outside interval)
    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq))
    else:
        reads = get_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq), primary_only=args.primary_only)

    # a set keeps the per-record membership test in parse_methbam constant time
    reads = set(reads)

    if phase:
        phased_reads_dict = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=HP_only)
        reads = {r for r in reads if phased_reads_dict.get(r) == phase}

    if hasattr(args, 'unambig_highlight') and hasattr(args, 'highlight'):
        if args.unambig_highlight and args.highlight:
            h_coords = []
            for h in args.highlight.split(','):
                if ':' in h:
                    h = h.split(':')[-1]

                h_coords += map(int, h.split('-'))

            h_coords.sort()

            # overall extent of all highlight boxes
            h_start, h_end = h_coords[0], h_coords[-1]

            excl_reads = set(get_ambiguous_reads(bam, chrom, h_start, h_end, min_mapq=int(args.min_mapq)))

            reads = reads - excl_reads

    seg_reads = {}

    def _add_call(index, cg_chrom, cg_start, stat, methstate, modname):
        # record one call when it is informative, for this mod, and in range
        if methstate == 0:
            return

        if modname != mod:
            return

        if chrom != cg_chrom:
            return

        if cg_start < elt_start or cg_start > elt_end:
            return

        cg_seg_start = cg_start - elt_start

        if index not in seg_reads:
            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
        else:
            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    if methbam:
        for row in parse_methbam(bam, reads, chrom, elt_start, elt_end, motifsize=args.motifsize, meth_thresh=0.8, can_thresh=0.8, restrict_motif=restrict_motif, restrict_ref=restrict_ref, primary_only=args.primary_only):
            _add_call(*row)

    else:
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            for index in reads:
                for conn in conns:
                    # read name passed as a bound parameter, not formatted into the SQL
                    for row in conn.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):
                        _add_call(index, *row)

        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    return seg_reads


def get_meth_profile_composite(args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase):
    """
    Build a smoothed methylation profile over one element for every input in
    `data`, expressed in the coordinates of the reference element (args.teref).

    For each input the per-read calls for `use_mod` inside
    seg_chrom:seg_start-seg_end are collected (from the mod-tagged BAM when
    `methbam` is true, otherwise from the sqlite database `data[bam]`),
    windowed with slide_window(), smoothed with smooth(), and finally lifted
    onto reference-element coordinates through a local alignment between the
    reference element and the genomic segment.

    Returns a dict. Inputs that were rejected (too few calls, too few windows,
    no/short alignment) are stored as `per_bam_results[bam] = None`; accepted
    inputs are stored under the BAM basename without its extension (plus
    '.phase<phase>' when `phase` is set) as a tuple
    (reference_positions, methylation_fractions, elt_info).
    """
    per_bam_results = {}

    # the reference element is the same for every input: load and upper-case it once
    te_ref_seq = single_seq_fa(args.teref).upper()

    # read names are bound as parameters so that names containing quotes cannot break the query
    read_query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"

    def db_rows(db_fn, read_names):
        # yield (readname, chrom, pos, stat, methstate, modname) and always release the connection
        conn = sqlite3.connect(db_fn)
        try:
            c = conn.cursor()
            for read_name in read_names:
                for row in c.execute(read_query, (read_name,)):
                    yield (read_name,) + tuple(row)
        finally:
            conn.close()

    for bam in data:
        logger.info('profiling %s: %s:%d-%d:%s:%s phase: %s' % (bam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, str(phase)))

        if args.excl_ambig:
            reads = exclude_ambiguous_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq))
        else:
            reads = get_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), primary_only=args.primary_only)

        if phase:
            phased_reads_dict = get_phased_reads(bam, seg_chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
            reads = [r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase]

        reads = set(reads)

        if methbam:
            rows = parse_methbam(bam, reads, seg_chrom, seg_start, seg_end, motifsize=len(args.motif), meth_thresh=0.8, can_thresh=0.8, restrict_motif=args.motif, restrict_ref=args.ref, primary_only=args.primary_only)
        else:
            rows = db_rows(data[bam], reads)

        seg_reads = {}

        for index, cg_chrom, cg_start, stat, methstate, modname in rows:
            if seg_chrom != cg_chrom:
                continue

            if cg_start < seg_start or cg_start > seg_end:
                continue

            if modname != use_mod:
                continue

            # positions are stored relative to the start of the segment
            cg_seg_start = cg_start - seg_start

            if index not in seg_reads:
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

        if args.max_read_density is not None:
            seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

        meth_table = dd(dict)
        sample = '.'.join(bam.split('.')[:-1])
        call_count = 0

        for name, read in seg_reads.items():
            for loc in read.llrs.keys():
                uuid = str(uuid4())
                meth_table[uuid]['loc'] = loc
                meth_table[uuid]['modstat'] = read.llrs[loc]
                meth_table[uuid]['read'] = name
                meth_table[uuid]['sample'] = sample
                meth_table[uuid]['call'] = read.meth_calls[loc]

                # only confident calls (1 or -1) count towards --mincalls
                if read.meth_calls[loc] in (1, -1):
                    call_count += 1

        if call_count < int(args.mincalls):
            logger.warning('too few calls on seg: %s:%d-%d (%d)' % (seg_chrom, seg_start, seg_end, call_count))
            per_bam_results[bam] = None
            continue

        meth_table = pd.DataFrame.from_dict(meth_table).T
        meth_table['loc'] = pd.to_numeric(meth_table['loc'])
        meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

        # windowing is done in "site rank" space; keep the original offsets to map back afterwards
        meth_table['orig_loc'] = meth_table['loc']
        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        cpg_to_coord = dict(zip(meth_table['loc'], meth_table['orig_loc']))

        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        if len(windowed_methfrac) <= int(args.smoothwindowsize):
            logger.warning('too few sites after windowing: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
            per_bam_results[bam] = None
            continue

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        coord_meth_pos = []

        for cpg in windowed_methfrac.keys():
            if cpg in cpg_to_coord:
                if seg_strand == '+':
                    coord_meth_pos.append(cpg_to_coord[cpg])
                if seg_strand == '-':
                    # flip the offset so that it is relative to the reverse-complemented segment
                    coord_meth_pos.append((seg_end-seg_start)-cpg_to_coord[cpg]-2)

        # alignment to ref elt

        ref = pysam.Fastafile(args.ref)
        try:
            elt_seq = ref.fetch(seg_chrom, seg_start, seg_end)
        finally:
            ref.close()

        if seg_strand == '-':
            elt_seq = rc(elt_seq)

        elt_seq = elt_seq.upper()

        s_ref = skseq.DNA(te_ref_seq)
        s_elt = skseq.DNA(elt_seq)

        try:
            aln_res = skalign.local_pairwise_align_ssw(s_ref, s_elt)
        except IndexError: # scikit-bio throws this if no bases align  >:|
            logger.warning('no align on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
            per_bam_results[bam] = None
            continue

        coord_ref, coord_elt = aln_res[2]

        len_ref = coord_ref[1] - coord_ref[0]
        len_elt = coord_elt[1] - coord_elt[0]

        if len_ref / len(te_ref_seq) < float(args.lenfrac):
            logger.warning('ref align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_ref / len(te_ref_seq)))
            per_bam_results[bam] = None
            continue

        if len_elt / len(elt_seq) < float(args.lenfrac):
            logger.warning('elt align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_elt / len(elt_seq)))
            per_bam_results[bam] = None
            continue

        tab_msa = aln_res[0]

        # walk the alignment columns to map element offsets onto reference offsets
        elt_to_ref_coords = {}

        pos_ref = coord_ref[0]
        pos_elt = coord_elt[0]

        for pos in tab_msa.iter_positions():
            pos = list(pos)
            b_ref = pos[0]
            b_elt = pos[1]

            if '-' not in pos:
                elt_to_ref_coords[pos_elt] = pos_ref
                pos_ref += 1
                pos_elt += 1

            if b_elt == '-':
                pos_ref += 1

            if b_ref == '-':
                # element base with no counterpart in the reference
                elt_to_ref_coords[pos_elt] = 'na'
                pos_elt += 1

        revised_coord_meth_pos = []
        meth_profile = []

        for pos, meth in zip(coord_meth_pos, smoothed_methfrac):
            if pos not in elt_to_ref_coords:
                continue

            revised_pos = elt_to_ref_coords[pos]

            if revised_pos != 'na':
                revised_coord_meth_pos.append(revised_pos)
                meth_profile.append(meth)

        bam_noext = '.'.join(os.path.basename(bam).split('.')[:-1])
        elt_info = '_'.join( map(str, (seg_chrom, seg_start, seg_end, seg_strand)))
        if phase:
            bam_noext += '.phase' + phase
        per_bam_results[bam_noext] = (revised_coord_meth_pos, meth_profile, elt_info)

    return per_bam_results


def get_meth_calls_wg(args, bam_fn, meth_fn, chrom, seg_start, seg_end, phased, mod, motifsize):
    """
    Tabulate per-site methylation counts for one genomic segment.

    Calls come from the mod-tagged BAM when args.methdb is None, otherwise from
    the sqlite database `meth_fn`. When `mod` is not None only calls for that
    modification are used.

    Returns [table_phase1, table_phase2]; each maps a genomic position to
    {'chr': chrom, 'N': number of non-ambiguous calls, 'X': number of
    methylated calls}. When `phased` is false everything is reported in the
    first table and the second one is empty.
    """
    methbam = args.methdb is None

    if not methbam and not os.path.exists(meth_fn):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_fn)

    # maps read name -> phase label (indexed by read name below)
    reads = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=phased, min_mapq=int(args.min_mapq), HP_only=True)

    seg_reads = {}

    if methbam:
        for row in parse_methbam(bam_fn, set(reads), chrom, seg_start, seg_end, meth_thresh=0.8, can_thresh=0.8, restrict_ref=args.ref, restrict_motif=args.motif, motifsize=motifsize, primary_only=args.primary_only):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if mod is not None and mod != modname:
                continue

            if chrom != cg_chrom:
                continue

            # NOTE(review): this branch excludes seg_end while the database branch
            # below includes it; left as-is, confirm which boundary is intended
            if cg_start < seg_start or cg_start >= seg_end:
                continue

            cg_seg_start = cg_start - seg_start

            if index not in seg_reads:
                # the phase must be attached here as well, otherwise every read is
                # dropped by the phase filter below when `phased` is set
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads.get(index))
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    else:
        # the read name is bound as a parameter so that quotes in names cannot break the query
        query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"

        conn = sqlite3.connect(meth_fn)

        try:
            c = conn.cursor()

            for index in reads:
                for row in c.execute(query, (index,)):
                    cg_chrom, cg_start, stat, methstate, modname = row

                    if mod is not None and mod != modname:
                        continue

                    if chrom != cg_chrom:
                        continue

                    if cg_start < seg_start or cg_start > seg_end:
                        continue

                    cg_seg_start = cg_start - seg_start

                    if index not in seg_reads:
                        seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads[index])
                    else:
                        seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)
        finally:
            conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    meth_data = {1: dd(list), 2: dd(list)}
    meth_table = {1: dd(dict), 2: dd(dict)}

    for name, read in seg_reads.items():
        if phased:
            # anything not labelled '1' or '2' (None, 'unphased', ...) is ignored
            if read.phase not in ['1','2']:
                continue

            phase = int(read.phase)
        else:
            # pass through as phase 1 if not analysing phases
            phase = 1

        for loc in read.llrs.keys():
            meth_data[phase][loc].append(read.meth_calls[loc])

    for phase in (1, 2):
        for loc, calls in meth_data[phase].items():
            N = len([call for call in calls if call != 0]) # call 0 == ambiguous
            X = len([call for call in calls if call == 1]) # call 1 == methylated

            if N > 0:
                pos = loc+seg_start
                meth_table[phase][pos]['chr'] = chrom
                meth_table[phase][pos]['N'] = N
                meth_table[phase][pos]['X'] = X

    return [meth_table[1], meth_table[2]]


def slide_window(meth_table, sample, width=20, slide=2):
    """
    Sliding-window methylated fraction for one sample (locus and composite plots).

    The window starts centred `width`/2 before the smallest 'loc', is advanced
    by `slide` before each measurement, and stops once its midpoint reaches the
    largest 'loc'. Window bounds are exclusive. Only calls of 1 (methylated)
    and -1 (unmethylated) are counted.

    Returns (meth_frac, meth_n), both keyed by window midpoint; windows
    without any counted call are omitted.
    """
    locs = meth_table['loc']
    last_loc = max(locs)

    lo = int(min(locs) - width/2)
    hi = lo + width

    # these selections do not depend on the window position
    in_sample = meth_table['sample'] == sample
    is_meth = in_sample & (meth_table['call'] == 1)
    is_unmeth = in_sample & (meth_table['call'] == -1)

    fractions = {}
    totals = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        inside = (locs > lo) & (locs < hi)

        n_meth = int((is_meth & inside).sum())
        n_unmeth = int((is_unmeth & inside).sum())
        n_calls = n_meth + n_unmeth

        if n_calls > 0:
            centre = int((lo+hi)/2)
            fractions[centre] = n_meth/n_calls
            totals[centre] = n_calls

    return fractions, totals


def slide_window_cover(cover_table, sample, width=20, slide=2):
    """
    Sliding-window mean of the 'cover' column for one sample (locus and
    composite plots).

    Windows are placed exactly as in the other slide_window* helpers: the
    window is advanced by `slide` before each measurement, bounds are
    exclusive, and the scan ends when the midpoint reaches the largest 'loc'.

    Returns a dict mapping window midpoint -> mean cover of the rows inside.
    """
    locs = cover_table['loc']
    last_loc = max(locs)

    lo = int(min(locs) - width/2)
    hi = lo + width

    in_sample = cover_table['sample'] == sample

    windows = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        inside = in_sample & (locs > lo) & (locs < hi)
        windows[int((lo+hi)/2)] = np.mean(cover_table.loc[inside, 'cover'])

    return windows


def slide_window_phased(meth_table, phase, width=20, slide=2):
    """
    Sliding-window methylated fraction restricted to rows whose 'phase'
    column equals `phase`.

    The window is advanced by `slide` before each measurement, bounds are
    exclusive, and the scan ends when the midpoint reaches the largest 'loc'.
    Only calls of 1 (methylated) and -1 (unmethylated) are counted.

    Returns (meth_frac, meth_n), both keyed by window midpoint; windows
    without any counted call are omitted.
    """
    locs = meth_table['loc']
    last_loc = max(locs)

    lo = int(min(locs) - width/2)
    hi = lo + width

    # these selections do not depend on the window position
    in_phase = meth_table['phase'] == phase
    is_meth = in_phase & (meth_table['call'] == 1)
    is_unmeth = in_phase & (meth_table['call'] == -1)

    fractions = {}
    totals = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        inside = (locs > lo) & (locs < hi)

        n_meth = int((is_meth & inside).sum())
        n_unmeth = int((is_unmeth & inside).sum())
        n_calls = n_meth + n_unmeth

        if n_calls > 0:
            centre = int((lo+hi)/2)
            fractions[centre] = n_meth/n_calls
            totals[centre] = n_calls

    return fractions, totals


def smooth(x, window_len=8, window='hanning'):
    # used for locus plots
    ''' modified from scipy cookbook: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    Smooth the 1-D array `x` by convolving it with a normalised window of
    `window_len` points. The signal is extended at both ends with mirrored
    copies of itself so that the result has the same length as `x`.

    `window_len` must be even. `window` is one of 'flat' (moving average),
    'hanning', 'hamming', 'bartlett' or 'blackman'.

    `x` is returned unchanged when it has fewer points than `window_len` or
    when `window_len` is smaller than 3.
    '''

    assert window_len % 2 == 0, '--smoothwindowsize must be an even number'
    assert x.ndim == 1

    if x.size < window_len:
        logger.warning('cannot smooth segment: fewer data points than window_len')
        return x

    if window_len < 3:
        return x

    # explicit lookup table instead of eval(): only these functions can ever be called
    window_funcs = {
        'flat': np.ones, # moving average
        'hanning': np.hanning,
        'hamming': np.hamming,
        'bartlett': np.bartlett,
        'blackman': np.blackman,
    }

    assert window in window_funcs, 'unknown smoothing window: %s' % str(window)

    # mirror window_len-1 points at each end to avoid edge effects
    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]

    w = window_funcs[window](window_len)

    y = np.convolve(w/w.sum(), s, mode='valid')

    # trim the padding so that the output lines up with the input
    half = int(window_len/2)
    return y[(half-1):-half]


def mask_methfrac(data, cutoff=20):
    """
    Locate the stretches of `data` that do not exceed `cutoff` (used for
    locus plots).

    Returns a list with one entry per maximal run of consecutive values that
    are <= int(cutoff); each entry is the list of indices making up that run.
    """
    above = np.asarray(data) > int(cutoff)

    segs = []

    # group consecutive indices by whether their value is above the cutoff
    for is_above, run in itertools.groupby(enumerate(above), key=lambda item: item[1]):
        if not is_above:
            segs.append([idx for idx, _ in run])

    return segs


def get_bed_annotations(fn, chrom, elt_start, elt_end):
    """
    Read the annotations in BED-like file `fn` that fall on `chrom` and
    overlap elt_start..elt_end.

    Expected whitespace-separated columns: chrom start end [label [strand
    [colour]]]. Intervals are clipped to the requested region. Lines with
    fewer than three fields (blank lines, short header lines) are skipped.

    Returns a list of BED namedtuples (start, end, label, strand, colour);
    missing optional columns are reported as None. Exits if the strand column
    of a selected record is neither '+' nor '-'.
    """
    # one record type for the whole file rather than a new class per line
    BED = namedtuple("BED", "start end label strand colour")

    ann = []

    elt_start = int(elt_start)
    elt_end = int(elt_end)

    with open(fn) as bed:
        for bed_rec in bed:
            c = bed_rec.strip().split()

            # blank or truncated lines cannot be unpacked into chrom/start/end
            if len(c) < 3:
                continue

            bed_chrom, bed_start, bed_end = c[:3]

            if bed_chrom != chrom:
                continue

            bed_start = int(bed_start)
            bed_end = int(bed_end)

            if bed_end < elt_start or bed_start > elt_end:
                continue

            # clip to the requested region
            bed_start = max(bed_start, elt_start)
            bed_end = min(bed_end, elt_end)

            bed_label = None
            bed_colour = None
            strand = None

            if len(c) > 3:
                bed_label = c[3]

            if len(c) > 4:
                if c[4] not in ('+', '-'):
                    sys.exit('format for --bed input must be: chrom start end label strand colour')

                strand = c[4]

            if len(c) > 5:
                bed_colour = c[5]

            ann.append(BED(start=bed_start, end=bed_end, label=bed_label, strand=strand, colour=bed_colour))

    return ann


def build_genes(gtf, chrom, start, end):
    """
    Collect gene models overlapping chrom:start-end (used for locus plots).

    `gtf` must provide fetch(chrom, start, end) yielding tab-separated GTF
    lines with nine fields. Records without a gene_id attribute are ignored;
    gene_name falls back to gene_id when absent.

    Returns a dict mapping gene_id -> Gene, with exon, CDS and transcript
    blocks added from the matching feature records.
    """
    genes = {}

    for line in gtf.fetch(chrom, start, end):

        # separate names so the region parameters are not overwritten by the record fields
        _, _, feature, blk_start, blk_end, _, strand, _, attribs = line.split('\t')

        block = [int(blk_start), int(blk_end)]

        attr_dict = {}

        for attrib in attribs.strip().split(';'):
            parts = attrib.strip().split(None, 1)

            # skip empty fragments and keys without a value instead of failing on unpacking
            if len(parts) != 2:
                continue

            key, raw = parts
            raw = raw.strip()

            if raw.startswith('"'):
                # quoted value: keep everything up to the closing quote, spaces included
                close = raw.find('"', 1)
                val = raw[1:close] if close != -1 else raw[1:]
            else:
                val = raw.split()[0]

            attr_dict[key] = val

        if 'gene_id' not in attr_dict:
            continue

        if 'gene_name' not in attr_dict:
            attr_dict['gene_name'] = attr_dict['gene_id']

        ensg = attr_dict['gene_id']
        name = attr_dict['gene_name']

        if ensg not in genes:
            genes[ensg] = Gene(ensg, name, strand)

        if feature == 'exon':
            genes[ensg].add_exon(block)

        if feature == 'CDS':
            genes[ensg].add_cds(block)

        if feature == 'transcript':
            genes[ensg].add_tx(block)

    return genes


def sample_db(db, mod, n=1000000):
    """
    Draw up to `n` randomly chosen 'stat' values for modification `mod` from
    the methylartist database `db`.

    Exits if `db` does not exist. Returns a flat numpy array of the sampled
    values.
    """
    if not os.path.exists(db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % db)

    logger.info('sample %d values from %s where modname="%s"' % (n, db, mod))

    conn = sqlite3.connect(db)

    try:
        # bound parameters: a double-quoted value in the SQL text could be taken for a column name
        rows = conn.execute('SELECT stat FROM methdata WHERE modname=? ORDER BY RANDOM() LIMIT ?', (mod, int(n))).fetchall()
    finally:
        conn.close()

    return np.asarray(rows).flatten()


## tools

def scoredist(args):
    """
    Plot the density of modification scores for one or more inputs and mark
    the methylated/unmethylated cutoffs with dashed vertical lines.

    Input is either a comma-separated list of methylartist databases
    (args.db, together with args.mod) or a comma-separated list of BAMs
    (args.bam, together with args.ref and args.motif), never both. For BAM
    input the cutoffs are fixed at 0.8 / 0.2.

    The plot is written to args.outfile when given, otherwise to a name
    derived from the input basenames with a .svg or .png extension.
    """
    sample = {}

    meth_cutoffs = []
    unmeth_cutoffs = []

    avail_mods = []

    out_fn = None

    if args.db is None and args.bam is None:
        sys.exit('please specify either -d/--db or -b/--bam')

    if args.motif is not None:
        assert iupac(args.motif)

    if args.db:
        if args.bam is not None:
            sys.exit('please specify either -d/--db or -b/--bam but not both')

        dbs = args.db.split(',')

        for db in dbs:
            # check before the database is opened for the first time, not afterwards
            if not os.path.exists(db):
                sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % db)

            avail_mods += get_modnames(db)

        avail_mods = list(set(avail_mods))

        if args.mod is None or args.mod not in avail_mods:
            logger.warning('mod %s not found, available mods: %s' % (str(args.mod), ','.join(avail_mods)))
            sys.exit(1)

        for db in dbs:
            upper, lower = get_cutoffs(db, args.mod)
            meth_cutoffs.append(upper)
            unmeth_cutoffs.append(lower)

            sample[os.path.basename(db)] = sample_db(db, args.mod, n=int(args.n))

        out_fn = '_'.join(map(os.path.basename, dbs)) + '.scoredist'

    elif args.bam:

        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bam')
            sys.exit(1)

        bams = args.bam.split(',')

        for bam in bams:
            meth_cutoffs.append(0.8)
            unmeth_cutoffs.append(0.2)
            sample[os.path.basename(bam)] = sample_bam(bam, args.motif, args.ref, int(args.n))

        out_fn = '_'.join(map(os.path.basename, bams)) + '.scoredist'

    # draw each distinct cutoff only once
    meth_cutoffs = list(set(meth_cutoffs))
    unmeth_cutoffs = list(set(unmeth_cutoffs))

    if len(meth_cutoffs) > 1:
        logger.warning('multiple upper (methylation) cutoffs:' + ','.join(map(str, meth_cutoffs)))

    if len(unmeth_cutoffs) > 1:
        logger.warning('multiple lower (unmethylation) cutoffs:' + ','.join(map(str, unmeth_cutoffs)))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    sns.kdeplot(data=sample, palette=args.palette, lw=float(args.lw), ax=ax)

    top = ax.get_ylim()[1]

    ax.vlines(meth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')
    ax.vlines(unmeth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')

    xmin, xmax = ax.get_xlim()

    if args.xmin:
        xmin = float(args.xmin)

    if args.xmax:
        xmax = float(args.xmax)

    ax.set_xlim(xmin, xmax)

    if args.svg:
        out_fn += '.svg'
    else:
        out_fn += '.png'

    if args.outfile is not None:
        out_fn = args.outfile

    logger.info('plot written to %s' % out_fn)
    plt.savefig(out_fn)

    # release the figure once it has been written
    plt.close(fig)


def adjustcutoffs(args):
    """
    Re-derive the methylation states stored in database args.db for
    modification args.mod from new cutoffs.

    All states for the modification are reset to 0, rows with
    stat > args.methylated become 1, rows with stat < args.unmethylated
    become -1, and the cutoffs table is updated to the new values. Everything
    is committed together at the end.
    """
    if not os.path.exists(args.db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % args.db)

    avail_mods = get_modnames(args.db)

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    upper = float(args.methylated)
    lower = float(args.unmethylated)

    conn = sqlite3.connect(args.db)

    try:
        c = conn.cursor()

        # values are bound as parameters: no quoting pitfalls and no rounding through '%f'
        logger.info('%s: reset methylation states to 0 for mod %s' % (args.db, args.mod))
        c.execute('UPDATE methdata SET methstate=0 WHERE modname=?', (args.mod,))

        logger.info('%s: mark sites with stat > %f methylated (1) for mod %s' % (args.db, upper, args.mod))
        c.execute('UPDATE methdata SET methstate=1 WHERE stat > ? AND modname=?', (upper, args.mod))

        logger.info('%s: mark sites with stat < %f unmethylated (-1) for mod %s' % (args.db, lower, args.mod))
        c.execute('UPDATE methdata SET methstate=-1 WHERE stat < ? AND modname=?', (lower, args.mod))

        logger.info('%s: update cutoff table for mod %s' % (args.db, args.mod))
        c.execute('UPDATE cutoffs SET upper=?, lower=? WHERE modname=?', (upper, lower, args.mod))

        conn.commit()
    finally:
        conn.close()

def db_guppy(args):
    """
    Build a methylartist database (<samplename>.guppy.db) from guppy fast5
    output. Deprecated: refuses to run unless args.force is set.

    Steps:
      1. cache the primary alignments of args.bam (comma-separated) in
         <samplename>.bamcache.db, unless that file already exists;
      2. run guppy_f5_fetch on every .fast5 file in args.fast5 using a pool
         of args.procs processes; each call returns a TSV filename or None;
      3. load every TSV row into the methdata table, shifting minus-strand
         positions by motifsize-1 and calling sites with args.minprob.
    """
    if not ont_fast5_api_installed:
        sys.exit('ont_fast5_api is not installed but is required for this function. Please install e.g. via "pip install ont-fast5-api"')

    if not args.force:
        sys.exit('This function (db-guppy) is depreciated. Please use use guppy to create a .bam file with modification tags and use that as input to methylartist')

    assert os.path.exists(args.fast5), 'path not found: %s' % args.fast5

    bam_db = args.samplename + '.bamcache.db'

    if os.path.exists(bam_db):
        logger.info('using existing bam cache db %s' % bam_db)

    else:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))
        conn = sqlite3.connect(bam_db)
        c = conn.cursor()

        logger.info('caching %s to %s' % (args.bam, bam_db))
        c.execute('''CREATE TABLE bam (readname text, sam text)''')
        c.execute('''CREATE INDEX read_index ON bam(readname)''')

        commit_interval = 100000

        for bam_fn in args.bam.split(','):
            bam = pysam.AlignmentFile(bam_fn)

            for i, read in enumerate(bam.fetch(), 1):
                if read.is_secondary or read.is_supplementary or read.is_duplicate:
                    continue

                # qualities are dropped before the record is serialised
                read.query_qualities = None

                # bound parameters: SAM text may contain quote characters
                c.execute("INSERT INTO bam VALUES (?, ?)", (read.query_name, read.to_string()))

                if i % commit_interval == 0:
                    logger.info('commiting %d records to %s...' % (commit_interval, bam_db))
                    conn.commit()

            bam.close()

            logger.info('commiting remaining records to %s...' % bam_db)

        conn.commit()
        conn.close()

    fast5s = [fn for fn in os.listdir(args.fast5) if fn.endswith('.fast5')]

    logger.info('found %d fast5 files in %s' % (len(fast5s), args.fast5))

    logger.info('fetching base modifications...')

    pool = mp.Pool(processes=int(args.procs))

    results = [pool.apply_async(guppy_f5_fetch, [fast5, bam_db, args]) for fast5 in fast5s]

    outfiles = [res.get() for res in results]

    # all results are in, so the workers can be released
    pool.close()
    pool.join()

    db_fn = args.samplename + '.guppy.db'

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    conn = sqlite3.connect(db_fn)
    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

    else:
        logger.info('appending records to %s' % db_fn)

    minprob = float(args.minprob)
    mod_base = args.modname

    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
    motif_shift = int(args.motifsize) - 1

    for fn in outfiles:
        if fn is None:
            continue

        llr = None

        with open(fn) as tsv:
            logger.info('loading %s into %s' % (fn, db_fn))
            csv_reader = csv.DictReader(tsv, delimiter='\t')

            # every row is processed and inserted inside this loop
            for row in csv_reader:
                readname = row['readname']
                chrom    = row['chrom']
                pos      = int(row['pos'])
                strand   = row['strand']

                # NOTE(review): despite the column names these values are compared with
                # minprob and summed as plain probabilities - confirm against guppy_f5_fetch
                mod_prob = float(row['mod_log_prob'])
                can_prob = float(row['can_log_prob'])

                llr = mod_prob - can_prob

                if strand == '-':
                    pos -= motif_shift

                methcall = 0

                if mod_prob >= minprob:
                    methcall = 1

                if can_prob >= minprob:
                    methcall = -1

                assert (mod_prob + can_prob) < 1.01 # allow for some rounding error

                # stat keeps two decimals, as before
                ins_data = (chrom, pos, strand, readname, round(llr, 2), methcall, mod_base)
                c.execute("INSERT INTO methdata VALUES (?,?,?,?,?,?,?)", ins_data)

            c.execute("INSERT INTO modnames VALUES (?)", (args.modname,))

            # NOTE(review): the cutoffs are taken from the last row's stat, as in the
            # original code - confirm that this is the intended threshold
            if not args.append and llr is not None:
                c.execute("INSERT INTO cutoffs VALUES (?, ?, ?)", (round(llr, 4), round(-1*llr, 4), args.modname))

            logger.info('commiting records from %s to %s' % (fn, db_fn))
            conn.commit()

    conn.close()
    logger.info('finished.')


def db_megalodon(args):
    """
    Load megalodon per-read modified base calls into a methylartist database.

    args.methdata is a comma-separated list of TSV files (optionally gzipped)
    with the columns read_id, chrm, pos, strand, mod_log_prob, can_log_prob
    and mod_base. The log probabilities are exponentiated; the modified-base
    probability is stored as 'stat' and a site is called methylated (1) or
    unmethylated (-1) when the respective probability reaches args.minprob.
    Minus-strand positions are shifted by motifsize-1.

    The database name is derived from args.methdata unless args.db is given.
    Exits if the database exists without --append, or is missing with
    --append.
    """
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.megalodon.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    minprob = float(args.minprob)

    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
    motif_shift = int(args.motifsize) - 1

    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            opener = gzip.open if tsv.endswith('.gz') else open

            logger.info('parsing ' + tsv)

            # dict keeps the modification names in order of first appearance
            modnames = {}

            ins_data = []

            # the input file is closed as soon as it has been read
            with opener(tsv, 'rt') as methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                # count from 1 so that batches are full and the logged total is exact
                for i, row in enumerate(csv_reader, 1):
                    readname = row['read_id']
                    chrom    = row['chrm']
                    pos      = int(row['pos'])
                    strand   = row['strand']
                    mod_prob = np.exp(float(row['mod_log_prob']))
                    can_prob = np.exp(float(row['can_log_prob']))

                    mod_base = row['mod_base']

                    modnames[mod_base] = True

                    if strand == '-':
                        pos -= motif_shift

                    methcall = 0

                    if mod_prob >= minprob:
                        methcall = 1

                    if can_prob >= minprob:
                        methcall = -1

                    ins_data.append((chrom, pos, strand, readname, mod_prob, methcall, mod_base))

                    if i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            for mod in modnames:
                c.execute("INSERT INTO modnames VALUES (?)", (mod,))
                c.execute("INSERT INTO cutoffs VALUES (?, ?, ?)", (round(minprob, 4), round(1-minprob, 4), mod))

            if ins_data:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            logger.info('commiting records from %s to %s' % (tsv, db_fn))
            conn.commit()
    finally:
        conn.close()


def db_custom(args):
    '''
    build (or append to) a methylartist sqlite database from generic per-read modification tables

    args attributes read here:
      methdata              comma-delimited list of input tables (optionally .gz)
      db, append            output database name / append to an existing database
      readname, chrom, pos,
      strand, modprob       zero-based column indices of the required fields
      canprob               optional column index of the canonical-base probability
                            (defaults to 1 - mod probability when not given)
      modbase, modbasecol   fixed modification name, or column index holding the name;
                            --modbase takes precedence when both are given
      minmodprob, mincanprob  call thresholds (mincanprob defaults to minmodprob)
      motifsize             motif length used to shift '-' strand positions
      header, delimiter     skip first line / explicit field delimiter

    exits via sys.exit() on inconsistent options or malformed rows.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.custom.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    # checked up front so that a bad invocation does not leave an empty database behind
    if args.modbase is None and args.modbasecol is None:
        sys.exit('must specify either --modbase or --modbasecol')

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    # option parsing does not depend on the input file, so do it once
    c_readname = int(args.readname)
    c_chrom    = int(args.chrom)
    c_pos      = int(args.pos)
    c_strand   = int(args.strand)
    c_modprob  = int(args.modprob)

    c_canprob    = int(args.canprob) if args.canprob is not None else None
    c_modbasecol = int(args.modbasecol) if args.modbasecol is not None else None

    minmodprob = float(args.minmodprob)
    mincanprob = float(args.mincanprob) if args.mincanprob is not None else minmodprob

    motif_offset = int(args.motifsize) - 1

    used_cols = [c_readname, c_chrom, c_pos, c_strand, c_modprob]
    used_cols += [col for col in (c_canprob, c_modbasecol) if col is not None]
    max_col = max(used_cols)

    progress_interval = 1000000

    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            opener = gzip.open if tsv.endswith('.gz') else open

            logger.info('parsing ' + tsv)

            # dict used as an insertion-ordered set of modification names seen in this file
            modnames = {}

            ins_data = []

            with opener(tsv, 'rt') as methdata:
                for i, row in enumerate(methdata):
                    if args.header and i == 0:
                        continue

                    if args.delimiter is not None:
                        cols = row.strip().split(args.delimiter)
                    else:
                        cols = row.strip().split()

                    if len(cols) < 5:
                        sys.exit('table %s has < 5 columns at line %d' % (tsv, i))

                    if len(cols) <= max_col:
                        sys.exit('table %s has %d columns at line %d but column index %d was requested' % (tsv, len(cols), i, max_col))

                    readname = cols[c_readname]
                    chrom    = cols[c_chrom]
                    pos      = int(cols[c_pos])
                    strand   = cols[c_strand]
                    mod_prob = float(cols[c_modprob])
                    can_prob = 1.0-mod_prob

                    if c_canprob is not None:
                        can_prob = float(cols[c_canprob])

                    # --modbase overrides the per-row column when both are supplied
                    if args.modbase is not None:
                        mod_base = args.modbase
                    else:
                        mod_base = cols[c_modbasecol]

                    modnames[mod_base] = True

                    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
                    if strand == '-':
                        pos -= motif_offset

                    methcall = 0

                    if mod_prob >= minmodprob:
                        methcall = 1

                    # the canonical call is evaluated last and wins if both thresholds are met
                    if can_prob >= mincanprob:
                        methcall = -1

                    ins_data.append((chrom, pos, strand, readname, mod_prob, methcall, mod_base))

                    if i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            # one modnames row and one cutoffs row per modification seen in this file
            for mod in modnames:
                c.execute('INSERT INTO modnames VALUES (?)', (mod,))
                c.execute('INSERT INTO cutoffs VALUES (?,?,?)', (round(minmodprob, 4), round(1-mincanprob, 4), mod))

            if len(ins_data) > 0:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            logger.info('commiting records from %s to %s' % (tsv, db_fn))
            conn.commit()

    finally:
        conn.close()


def db_nanopolish(args):
    '''
    build (or append to) a methylartist sqlite database from nanopolish call-methylation tables

    args attributes read here:
      methdata     comma-delimited list of tab-delimited tables (optionally .gz); columns
                   used: chromosome, strand, start, read_name, log_lik_ratio, sequence and,
                   with --scalegroup, num_motifs
      db, append   output database name / append to an existing database
      modname      modification name stored with every record
      thresh       absolute log-likelihood ratio cutoff: llr > thresh is a modified call,
                   llr < -thresh is an unmodified call, otherwise no call
      scalegroup   divide the llr by the number of motifs in the group
      motif        motif searched for in the kmer to recover per-site positions

    rows that cannot be parsed are reported and skipped; exits via sys.exit() if the
    motif is absent from a kmer.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    if basename.endswith('.tsv'):
        basename = '.'.join(basename.split('.')[:-1])

    db_fn = basename + '.nanopolish.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    thresh = float(args.thresh)
    mod_base = args.modname

    progress_interval = 500000

    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            opener = gzip.open if tsv.endswith('.gz') else open

            logger.info('parsing ' + tsv)

            ins_data = []

            with opener(tsv, 'rt') as methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                for i, row in enumerate(csv_reader, 1):
                    # DictReader pads short rows with None, so TypeError covers truncated lines
                    try:
                        r_start = int(row['start'])
                        llr     = float(row['log_lik_ratio'])
                        seq     = row['sequence']

                        if seq is None:
                            raise ValueError('missing sequence')

                        if args.scalegroup:
                            llr = llr/float(row['num_motifs'])

                    except (KeyError, TypeError, ValueError, ZeroDivisionError):
                        logger.warning('bad line %d: %s' % (i, str(row)))
                        continue

                    methcall = 0

                    if llr > thresh:
                        methcall = 1

                    elif llr < thresh*-1:
                        methcall = -1

                    # get per-CG position (nanopolish/calculate_methylation_frequency.py)
                    if args.motif not in seq:
                        sys.exit('motif %s not found in kmer %s, please check --motif setting' % (args.motif, seq))

                    cg_pos = seq.find(args.motif)
                    first_cg_pos = cg_pos

                    # 'start' refers to the first motif in the kmer; later motifs are offset from it
                    while cg_pos != -1:
                        cg_start = r_start + cg_pos - first_cg_pos
                        cg_pos = seq.find(args.motif, cg_pos + 1)

                        ins_data.append((row['chromosome'], cg_start, row['strand'], row['read_name'], llr, methcall, mod_base))

                    if i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            if len(ins_data) > 0:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            c.execute('INSERT INTO modnames VALUES (?)', (args.modname,))

            c.execute('INSERT INTO cutoffs VALUES (?,?,?)', (round(thresh, 4), round(thresh*-1, 4), args.modname))

            logger.info('commiting records from %s to %s' % (tsv, db_fn))

            conn.commit()

    finally:
        conn.close()


def segmeth(args):
    '''
    segment methylation stats over genomic intervals

    For every sample (and phase, with --phased) and every interval in args.intervals,
    get_segmeth_calls() is dispatched to a process pool. Each non-empty result is
    unpacked as (meth_result, seg) where seg is (chrom, start, end, name, strand,
    read_count) and meth_result maps a mod name to per-call-state counts keyed
    by -1 (unmodified), 0 (no call) and 1 (modified).

    Output is a tab-delimited table with one row per segment and, per sample and mod,
    the columns listed in `stats`. Segments lacking any expected column are reported
    and left out of the table.

    Samples come either from -d/--data (a table of .bam / methylation .db pairs) or from
    -b/--bams (mod-tagged .bams, which additionally requires --ref and --motif).
    '''

    stats = [
    'meth_calls',
    'unmeth_calls',
    'no_calls',
    'methfrac',
    'readcount'
    ]

    mod_names = []

    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.motif is not None:
        assert iupac(args.motif)

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as data_file:
            for line in data_file:
                bam, meth = line.strip().split()[:2]
                for m_db in meth.split(','):
                    data[bam].append(m_db)

                    for m in get_modnames(m_db):
                        mod_names.append(m)

    if args.bams is not None:
        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bams')
            sys.exit(1)

        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    fields = line.strip().split()
                    if fields:
                        bams.append(fields[0])

        if not bams:
            sys.exit('no .bam files found in %s' % args.bams)

        for bam in bams:
            # an optional ':suffix' on a bam name is not used by segmeth
            if ':' in bam:
                bam, _ = bam.split(':')

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # mod names are taken from the last .bam in the list
        mod_names = mods_methbam(bam)

    mod_names = list(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    base_names = ['.'.join(bam.split('.')[:-1]) for bam in data]

    if args.phased:
        phased_base_names = []
        for bn in base_names:
            phased_base_names.append(bn+'_ph1')
            phased_base_names.append(bn+'_ph2')
        base_names = phased_base_names

    data_basename = None

    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'

    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])

    ivl_basename = '.'.join(os.path.basename(args.intervals).split('.')[:-1])
    outfn = '.'.join((ivl_basename, data_basename))

    if args.excl_ambig:
        outfn += '.excl_ambig'

    if args.spanning_only:
        outfn += '.spanning_only'

    if args.phased:
        outfn += '.phased'

    outfn += '.segmeth.tsv'

    if args.outfile is not None:
        outfn = args.outfile

    phases = [None]

    if args.phased:
        phases = ['1', '2']

    # the interval list is identical for every sample/phase, so parse it only once
    segments = []

    with open(args.intervals) as ivl_file:
        for line in ivl_file:
            c = line.strip().split()
            chrom, seg_start, seg_end = c[:3]

            seg_name = 'NA'
            seg_strand = 'NA'

            if len(c) > 3:
                seg_name = c[3]

            if len(c) > 4:
                seg_strand = c[4]

            segments.append((chrom, int(seg_start), int(seg_end), seg_name, seg_strand))

    with open(outfn, 'w') as out:
        logger.info('segmeth output filename: %s' % outfn)

        meth_segs = dd(dict)

        # every result is collected inside the block, so the pool is idle when it is torn down
        with mp.Pool(processes=int(args.procs)) as pool:
            results = []

            for bam_fn, meth_fn in data.items():
                for phase in phases:
                    base_name = '.'.join(bam_fn.split('.')[:-1])

                    if phase is not None:
                        base_name += '_ph' + phase

                    for chrom, seg_start, seg_end, seg_name, seg_strand in segments:
                        res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mod_names, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam])

                        results.append((res, base_name))

            for res, base_name in results:
                meth_result, seg = res.get()

                if meth_result is None:
                    continue

                seg_id = '%s:%d-%d' % seg[:3]

                seg_chrom, seg_start, seg_end, seg_name, seg_strand, read_count = map(str, seg)

                seg_rec = meth_segs[seg_id]

                seg_rec['seg_id']     = seg_id
                seg_rec['seg_chrom']  = seg_chrom
                seg_rec['seg_start']  = seg_start
                seg_rec['seg_end']    = seg_end
                seg_rec['seg_name']   = seg_name
                seg_rec['seg_strand'] = seg_strand

                if seg_name == 'NA':
                    seg_rec['seg_name'] = 'NoName'

                if seg_strand == 'NA':
                    seg_rec['seg_strand'] = '.'

                for modname, meth_data in meth_result.items():
                    no_calls = 0
                    meth_calls = 0
                    unmeth_calls = 0

                    if -1 in meth_data:
                        unmeth_calls = meth_data[-1]

                    if 0 in meth_data:
                        no_calls = meth_data[0]

                    if 1 in meth_data:
                        meth_calls = meth_data[1]

                    prefix = os.path.basename(base_name) + '_' + modname

                    seg_rec[prefix + '_meth_calls'] = meth_calls
                    seg_rec[prefix + '_unmeth_calls'] = unmeth_calls
                    seg_rec[prefix + '_no_calls'] = no_calls

                    if meth_calls+unmeth_calls == 0:
                        seg_rec[prefix + '_methfrac'] = 'NaN'
                    else:
                        seg_rec[prefix + '_methfrac'] = meth_calls/float(meth_calls+unmeth_calls)

                    seg_rec[prefix + '_readcount'] = read_count

        header = ['seg_id', 'seg_chrom', 'seg_start', 'seg_end', 'seg_name', 'seg_strand']

        for bn in base_names:
            for m in mod_names:
                for s in stats:
                    header.append('%s_%s_%s' % (os.path.basename(bn), m, s))

        out.write('\t'.join(header) + '\n')

        for mseg, seg_rec in meth_segs.items():
            # a segment missing any column is skipped entirely rather than written as a short row
            missing = next((h for h in header if h not in seg_rec), None)

            if missing is not None:
                logger.warning('no calls for sample %s in segment %s, skipped.' % (missing, mseg))
                continue

            out.write('\t'.join(str(seg_rec[h]) for h in header) + '\n')


def segplot(args):
    '''
    strip plots / violin plots of segmented methylation data over categories by sample

    Reads a segmeth table (args.segmeth), recomputes the per-sample/mod methylated
    fraction from the *_meth_calls and *_unmeth_calls columns, keeps segments that pass
    --mincalls and --minreads for every selected sample/mod, writes the long-format
    plotting table to <basename>.segplot_data.csv and saves a strip plot (default),
    violin plot (--violin) or ridge plot (--ridge) as .png or .svg.
    '''

    if args.ridge:
        # compare (major, minor, patch) as a tuple so that releases >= 1.0 are accepted;
        # only the leading digits of each part are used (e.g. '2rc1' -> 2)
        sns_version = []
        for part in sns.__version__.split('.')[:3]:
            digits = ''.join(itertools.takewhile(str.isdigit, part))
            sns_version.append(int(digits) if digits else 0)

        if tuple(sns_version) < (0, 11, 2):
            sys.exit('--ridge requires seaborn 0.11.2 or later')

    data = pd.read_csv(args.segmeth, sep='\t', header=0, index_col=0)

    samples = []
    mods = []

    # column names look like <sample>_<mod>_meth_calls; the mod is the last '_' field
    for c in data.columns:
        if c.endswith('_meth_calls'):
            sample_mod = c.replace('_meth_calls', '')
            sample = '_'.join(sample_mod.split('_')[:-1])
            mod = sample_mod.split('_')[-1]
            if sample not in samples:
                samples.append(sample)
            if mod not in mods:
                mods.append(mod)

    logger.info('available samples: %s' % ','.join(samples))
    logger.info('available mods: %s' % ','.join(mods))

    if args.samples:
        user_samples = []
        for s in args.samples.split(','):
            if s not in samples:
                sys.exit('%s not in sample list!' % s)
            user_samples.append(s)

        samples = user_samples

    if args.mods:
        user_mods = []
        for m in args.mods.split(','):
            if m not in mods:
                sys.exit('%s not in mod list!' % m)
            user_mods.append(m)

        mods = user_mods

    samples_mods = []
    for s in samples:
        for m in mods:
            sample_mod = s + '_' + m
            if sample_mod not in samples_mods:
                samples_mods.append(sample_mod)

    logger.info('sample + mod permutations: %s' % ','.join(samples_mods))

    for s in samples:
        for m in mods:
            sm = s + '_' + m
            data[sm + '_methfrac'] = data[sm + '_meth_calls']/(data[sm + '_meth_calls']+data[sm + '_unmeth_calls'])

    mincalls = int(args.mincalls)
    minreads = int(args.minreads)

    useable = []

    # a segment is only kept if every selected sample/mod passes both filters
    for seg in data.index:
        use_seg = True

        for s in samples:
            for m in mods:
                sm = s + '_' + m

                if (data[sm + '_meth_calls'].loc[seg] + data[sm + '_unmeth_calls'].loc[seg]) < mincalls:
                    use_seg = False
                    continue

                if data[sm+'_readcount'].loc[seg] < minreads:
                    use_seg = False
                    continue

        if use_seg:
            useable.append(seg)

    data = data.loc[useable]

    logger.info('useable sites: %d' % len(useable))

    if len(useable) < 1:
        sys.exit('exiting: no useable sites, possibly due to lack of coverage or lack of modification calls')

    plot_data = dd(dict)

    # groups (seg_name values) in order of first appearance
    order = []

    for seg in data.index:
        for s in samples:
            for m in mods:
                sm = s + '_' + m
                uid = seg + ':' + sm
                plot_data[uid]['sample']  = sm
                plot_data[uid]['modbase'] = data[sm + '_methfrac'].loc[seg]
                plot_data[uid]['group']   = data['seg_name'].loc[seg]

                if args.ridge:
                    plot_data[uid]['samplegroup'] = sm + ' ' + plot_data[uid]['group']

                if plot_data[uid]['group'] not in order:
                    order.append(plot_data[uid]['group'])

    plot_data = pd.DataFrame.from_dict(plot_data).T
    plot_data = pd.DataFrame(plot_data.to_dict())

    basename = '.'.join(args.segmeth.split('.')[:-1])

    basename += '.mc%d' % mincalls
    basename += '.mr%d' % minreads

    plot_data.to_csv(basename+'.segplot_data.csv')
    logger.info('plot data written to %s.segplot_data.csv' % basename)

    pt_sz = int(args.pointsize)

    if args.categories is not None:
        order = args.categories.split(',')
        plot_data = plot_data[plot_data['group'].isin(order)]

    if args.violin:
        if args.group_by_annotation:
            sns_plot = sns.violinplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, jitter=True, order=order, hue_order=samples_mods, palette=args.palette)
        else:
            sns_plot = sns.violinplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, jitter=True, hue_order=order, palette=args.palette)

        basename += '.violin'

    elif args.ridge:
        sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

        sample_order = CategoricalDtype(samples_mods, ordered=True)
        plot_data['sample'] = plot_data['sample'].astype(sample_order)

        if args.categories:
            cat_order = CategoricalDtype(order, ordered=True)
            plot_data['group'] = plot_data['group'].astype(cat_order)

        plot_data = plot_data.sort_values(['sample','group'])

        if args.group_by_annotation:
            plot_data = plot_data.sort_values(['group','sample'])

        sns_plot = sns.FacetGrid(plot_data, row='samplegroup', hue='sample', aspect=15, height=.5, palette=args.palette)
        sns_plot.map(sns.kdeplot, 'modbase', bw_adjust=float(args.ridge_smoothing), clip_on=False, fill=True, alpha=float(args.ridge_alpha), linewidth=1.5)
        sns_plot.map(sns.kdeplot, 'modbase', clip_on=False, color='w', lw=2, bw_adjust=float(args.ridge_smoothing), alpha=float(args.ridge_alpha))
        sns_plot.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)
        sns_plot.figure.subplots_adjust(hspace=float(args.ridge_spacing))

        seen_samples = set()
        seen_groups = set()

        s_col = sns.color_palette(args.palette, n_colors=len(set(plot_data['sample'])))

        # index into s_col, cycling back to the first colour after the last one
        s = 0

        for i in range(len(set(plot_data['samplegroup']))):
            ax = sns_plot.axes[i,0]
            # facet titles end in '= <sample> <group>'; recover both names from the title text
            samplename, groupname = ax.title.get_text().split('=')[-1].strip().split()

            if args.group_by_annotation and groupname not in seen_groups:
                ax.text(0, .15, groupname, color='k', ha='left', va='center', transform=ax.transAxes)

            if args.group_by_annotation and samplename not in seen_samples:
                ax.text(0, -.19, samplename, size=10, color='k', ha='left', va='center', transform=ax.transAxes)

            if not args.group_by_annotation:
                ax.text(0, .15, samplename, weight='bold', color=s_col[s], ha='left', va='center', transform=ax.transAxes)
                s += 1
                if s == len(s_col):
                    s = 0

            if not args.group_by_annotation and groupname not in seen_groups:
                ax.text(0, -.19, groupname, size=10, color='k', ha='left', va='center', transform=ax.transAxes)


            seen_groups.add(groupname)
            seen_samples.add(samplename)

        sns_plot.set_titles("")
        sns_plot.set(yticks=[], ylabel="")
        sns_plot.despine(bottom=True, left=True)

        basename += '.ridge'

    else:
        if args.group_by_annotation:
            sns_plot = sns.stripplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, jitter=True, size=pt_sz, order=order, hue_order=samples_mods, palette=args.palette)
        else:
            sns_plot = sns.stripplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, jitter=True, size=pt_sz, hue_order=order, palette=args.palette)

    if args.group_by_annotation:
        basename += '.group_by_annotation'

    if not args.ridge:
        sns_plot.set_xlabel("")
        sns_plot.set_ylabel(args.ylabel)

        if args.tiltlabel:
            sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=45)

        sns_plot.set_ylim(float(args.ymin),float(args.ymax))

        if args.width is None:
            args.width = 1 + len(order) * len(samples_mods)
            logger.info('auto set --width: %d' % args.width)

    fig = sns_plot.figure

    # figure size is only set explicitly for the non-ridge plots
    if not args.ridge:
        fig.set_size_inches(float(args.width), float(args.height))

    outfn = basename

    if args.svg:
        outfn += '.segplot.svg'
    else:
        outfn += '.segplot.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg:
            if outfn.split('.')[-1] != 'svg':
                logger.warning('warning: %s does not have extension .svg, appending' % outfn)
                outfn += '.svg'

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)


def locus(args):
    '''
    plot methylation profile of a region in CpG space
    '''
    # set up

    assert ':' in args.interval
    assert '-' in args.interval

    if args.panelratios:
        if args.plot_coverage:
            if len(args.panelratios.split(',')) != 6:
                logger.warning('locus -p/--panelratios requires 6 terms e.g. --panelratios 1,5,1,3,3,3') 
                sys.exit(1)
        else:
            if len(args.panelratios.split(',')) != 5:
                logger.warning('locus -p/--panelratios requires 5 terms e.g. --panelratios 1,5,1,3,3') 
                sys.exit(1)

    if args.motif is not None:
        assert iupac(args.motif)

    chrom, pos= args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))
    
    data = dd(list)
    user_colours = {}

    markers = ['.',',','o','v','^','<','>','8','s','p','P','*','h','H','X','D','d']

    if args.readmarker not in markers:
        sys.exit('%s is not a valid marker type, valid types: %s' % (args.readmarker, ','.join(markers)))

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)

                if len(c) == 3:
                    user_colours[bam] = c[2]

    if args.motif is not None:
        if len(args.motif) != int(args.motifsize):
            logger.warning('motif size (set with --motifsize) %d does not match length of --motif (%s), changed --motifsize' % (int(args.motifsize), args.motif))
            args.motifsize = len(args.motif)

    if args.bams is not None:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

        if None in (args.ref, args.motif):
            logger.warning('--ref and --motif are required when using --bams')
            sys.exit(1)

        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    c = line.strip().split()
                    if len(c) == 1:
                        bams.append(c[0])
                    elif len(c) == 2:
                        bams.append(c[0])
                        user_colours[bams[-1]] = c[1]
                    else:
                        sys.exit(f'unparsable line in {args.bams}: {line.strip()}')

        for bam in bams:
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

    # table for plotting

    meth_table = dd(dict)

    sample_order = []

    reads = {}
    orig_bam = {}

    use_mods = None
    if args.mods:
        use_mods = args.mods.split(',')

    phases = {}

    if args.phased:
        for bam in data:
            phased_reads = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=args.include_unphased, min_mapq=int(args.min_mapq), HP_only=args.ignore_ps)
            phases[bam] = list(set([p for p in phased_reads.values() if p]))
            logger.info('phases for bam %s: %s' % (bam, ','.join(phases[bam])))

    for bam, meth_dbs in data.items():
        if meth_dbs is None:
            mods = mods_methbam(bam)
            logger.info('found mods: %s in bam %s' % (','.join(mods), bam))

        else:
            for meth_db in meth_dbs:
                mods = sorted(get_modnames(meth_db))
                logger.info('found mods: %s in db %s' % (','.join(mods), meth_db))

        for mod in mods:
            if use_mods:
                if mod not in use_mods:
                    logger.info('skipping %s, not specified in -m/--mods %s' % (mod, args.mods))
                    continue
            else:
                use_mods = mods

            if args.phased:
                for phase in phases[bam]:
                    bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + phase + '.' + mod
                    orig_bam[bamname] = bam
                    reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, phase=phase, methbam=methbam, HP_only=args.ignore_ps, restrict_motif=args.motif, restrict_ref=args.ref)

                    for name, read in reads[bamname].items():
                        for loc in read.llrs.keys():
                            uuid = str(uuid4())
                            meth_table[uuid]['loc'] = loc
                            meth_table[uuid]['modstat'] = read.llrs[loc]
                            meth_table[uuid]['read'] = name
                            meth_table[uuid]['sample'] = bamname
                            meth_table[uuid]['call'] = read.meth_calls[loc]

                            if bamname not in sample_order:
                                sample_order.append(bamname)

            else:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, methbam=methbam, restrict_motif=args.motif, restrict_ref=args.ref)

                for name, read in reads[bamname].items():
                    for loc in read.llrs.keys():
                        uuid = str(uuid4())
                        meth_table[uuid]['loc'] = loc
                        meth_table[uuid]['modstat'] = read.llrs[loc]
                        meth_table[uuid]['read'] = name
                        meth_table[uuid]['sample'] = bamname
                        meth_table[uuid]['call'] = read.meth_calls[loc]

                        if bamname not in sample_order:
                            sample_order.append(bamname)

    meth_table = pd.DataFrame.from_dict(meth_table).T

    if 'loc' not in meth_table:
        sys.exit('%s: insufficient coverage for plot' % args.interval)

    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    if len(meth_table['orig_loc']) == 0:
        sys.exit('%s: insufficient coverage for plot' % args.interval)        

    # optional mincall filter

    if int(args.mincalls) > 0:
        drop_ix = []

        logger.info('assessing coverage per --mincalls, per-read sites to consider: %d' % len(meth_table['loc']))
        tick = 5000
        for i, new_loc in enumerate(meth_table['loc'], 1):
            for sample in sample_order:
                call_count = len(meth_table.loc[(meth_table['sample'] == sample) & (meth_table['loc'] == new_loc)])

                if call_count < int(args.mincalls):
                    drop_ix += list(meth_table.loc[(meth_table['loc'] == new_loc)].index)
            
            if i % tick == 0:
                logger.info('%s %s: processed %d sites' % (args.data, args.interval, i))
        
        drop_ix = list(set(drop_ix))
        logger.info('dropping %d positions' % len(drop_ix))
        meth_table.drop(drop_ix, inplace=True)

        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        if len(meth_table['orig_loc']) == 0:
            sys.exit('%s: insufficient coverage for plot' % args.interval)

    # create mod space

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    # calibrate plotting parameters
    logger.info('region size: %d' % (elt_end-elt_start))

    modspace_n = len(list(set(coord_to_cpg.values())))
    logger.info('mod space positions: %d' % modspace_n)

    if args.modspace is not None:
        logger.info('user set --modspace %d, mod_n:ticks: %.3f' % (int(args.modspace), modspace_n/float(args.modspace)))
        args.modspace = int(args.modspace)
    else:
        if modspace_n <= 1200:
            args.modspace = 1
        else:
            args.modspace = round(modspace_n/1200)

        logger.info('auto set --modspace %d' % args.modspace)

    if args.smoothwindowsize is not None:
        logger.info('user set --smoothwindowsize %d, mod_n:smoothwindowsize: %.3f' % (int(args.smoothwindowsize), modspace_n/float(args.smoothwindowsize)))
        args.smoothwindowsize = int(args.smoothwindowsize)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            logger.info('-s/--smoothwindowsize must be an even integer, adjusted value: %d' % args.smoothwindowsize)

    else:
        args.smoothwindowsize = round(0.0167*modspace_n + 18)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1

        logger.info('auto set --smoothwindowsize %d' % args.smoothwindowsize)

    # data for coverage plot

    cover_bams = {}

    if args.plot_coverage:
        cov_bams = []

        if args.plot_coverage.endswith('.bam'):
            cov_bams = args.plot_coverage.split(',')

        if args.plot_coverage.endswith('.txt'):
            with open(args.plot_coverage) as _:
                for fn in _:
                    cov_bams.append(fn.strip())

        for bam_fn in cov_bams:
            logger.info('gather coverage from %s...' % bam_fn)

            meth_seg_starts = deepcopy(np.asarray(meth_table['orig_loc']))
            meth_seg_starts += elt_start

            cover = bam_pileupcover(bam_fn, chrom, meth_seg_starts, meth_seg_starts, procs=int(args.coverprocs), log=args.logcover)

            c_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])

            cover_table = dd(dict)
            sorted_locs = sorted(list(set(meth_table['loc'])))

            for c, loc in zip(cover, sorted_locs):
                uuid = str(uuid4())
                cover_table[uuid]['loc'] = loc
                cover_table[uuid]['sample'] = c_name
                cover_table[uuid]['cover'] = c
            
            cover_table = pd.DataFrame.from_dict(cover_table).T

            windowed_cover = slide_window_cover(cover_table, c_name, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))
            cover_bams[c_name] = smooth(np.asarray(list(windowed_cover.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)


    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s - elt_start)
            h_end.append(h_e - elt_start)

            h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()

                if c[0] != chrom:
                    continue

                h_s, h_e = map(int, c[1:3])

                assert h_s < h_e

                if h_s < elt_start:
                    continue

                if h_e > elt_end:
                    continue

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s - elt_start)
                h_end.append(h_e - elt_start)

                h_cpg_start.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)
            
            logger.info('found %d within-interval highlight regions in %s' % (len(h_start), args.highlight_bed))

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot
    p_ratios=[1,5,1,3,3]
    if args.plot_coverage:
        p_ratios.append(3)

    fig = plt.figure()
    gs = gridspec.GridSpec(len(p_ratios),1,height_ratios=p_ratios, hspace=0)

    img_w = 16
    img_h = 8

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        if args.plot_coverage and len(p_ratios) < 5:
            logger.warning('--panelratios must have 5 values if used with --plot_coverage')
            sys.exit(1)

        gs = gridspec.GridSpec(len(p_ratios),1,height_ratios=p_ratios, hspace=0)

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_color = {}
            for hp in ('1','2'):
                hp_color[hp] = sns.color_palette(args.samplepalette, n_colors=2)[int(hp)-1]
            
            for sample in sample_color:
                hp = sample.split('.')[-2]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phase')

    cover_color = {}

    if args.plot_coverage:
        for i, c_name in enumerate(cover_bams.keys()):
            cover_color[c_name] = sns.color_palette(args.coverpalette, n_colors=len(cover_bams))[i]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, elt_start, elt_end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    y = 1

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            logger.warning('no transcript for gene: %s' % genes[ensg].name)
            continue

        logger.info('gene in region: %s' % genes[ensg].name)

        y = 1

        while genemap[y].find(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start):
            y += 1.5

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        if args.labelgenes:
            label = genes[ensg].name

            if genes[ensg].strand in ('+', '-'):
                if genes[ensg].strand == '+':
                    label += '>>'
                
                else:
                    label = '<<' + label

            lg_x = max(genes[ensg].tx_start-elt_start, 0)

            nudge_up = 0.05
            if genes[ensg].tx_start-elt_start < 0:
                nudge_up = 0.7

            gtxt = ax0.text(lg_x, y+nudge_up, label, zorder=4, fontsize='small')
            bb_w = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(elt_end-elt_start)
            gtxt.set_x(lg_x-txt_w*1.42)

            genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start-(txt_w*1.5), genes[ensg].tx_end-elt_start+(txt_w*1.5)))

        else:
            genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start))

    gene_height = max(4, len(genemap))

    if args.labelgenes:
        gene_height = max(8, len(genemap))

    if args.bed:
        logger.info('loading annotations from %s' % args.bed)
        gene_height += 3

        y += 1

        bed_annotations = get_bed_annotations(args.bed, chrom, elt_start, elt_end)
        bed_colours = sns.color_palette(args.genepalette, n_colors=len(bed_annotations))

        for i, ann in enumerate(bed_annotations):
            bed_colour = bed_colours[i]

            if ann.colour is not None:
                bed_colour = ann.colour

            exon_patches.append(matplotlib.patches.Rectangle([ann.start-elt_start, y], (ann.end-ann.start), 1.0, edgecolor=bed_colour, facecolor=bed_colour, zorder=3))

            if ann.label is not None:
                lg_x  = max(ann.start-elt_start, 0)
                label = ann.label

                if ann.strand is not None:
                    if ann.strand == '+':
                        label += '>>'
                    
                    else:
                        label = '<<' + label

                gtxt  = ax0.text(lg_x, y, label, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    logger.info('building read alignment plot...')
    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    readstack = dd(list)

    max_y  = 1
    pack_y = 1

    masked_count = 0

    for bamname in reads:
        fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
        pos_cache = {}

        for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
            if read.mapq < int(args.min_mapq):
                continue

            if read.is_supplementary or read.is_secondary or read.is_duplicate:
                if not args.allreads:
                    continue

            masked = False
            if len(readmask) > 0:
                for mask_start, mask_end in readmask:
                    if read.reference_start >= mask_start and read.reference_end <= mask_end:
                        logger.debug('masked read: %s' % read.query_name)
                        masked_count += 1
                        masked = True
            
            if masked:
                continue

            if read.query_name not in pos_cache:
                pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
            else:
                pos_cache[read.query_name][0].append(read.reference_start)
                pos_cache[read.query_name][1].append(read.reference_end)

        for readname, read in reads[bamname].items():
            if readname not in pos_cache:
                logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                continue

            if read.call_count == 0:
                continue

            read.ypos = max_y
            max_y += 1

            read.starts, read.ends = pos_cache[readname]
            readstack[bamname].append(read)

        fetch_reads_bam.close()

    if args.readmask:
        logger.info('masked %d reads due to --readmask' % masked_count)

    # read packing

    for bamname in readstack:
        reads = readstack[bamname]

        y = dd(list)

        for read in readstack[bamname]:
            y[read.ypos].append(read)

        for p in y:
            for q in y:
                if p == q:
                    continue

                for read_q in y[q]:
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

        for p in y:
            if len(y[p]) > 0:
                for read in sorted(y[p], key=lambda r: min(r.starts)):
                    read.ypos = pack_y
                pack_y += 1

    ax1.set_ylim(0,pack_y+1)

    for bamname in readstack:
        for read in readstack[bamname]:
            rm = args.readmarker
            ms = float(args.readmarkersize)
            lw = float(args.readlinewidth)
            la = float(args.readlinealpha)
            ma = float(args.markeralpha)
            uec = sample_color[bamname]

            if args.readopenmarkeredgecolor is not None:
                uec = args.readopenmarkeredgecolor

            for call_pos, call in read.meth_calls.items():
                if call == -1:
                    ax1.plot(call_pos, read.ypos, marker=rm, fillstyle='full', mec=uec, mfc='white', markersize=ms, zorder=3, alpha=ma)

                if call == 1:
                    ax1.plot(call_pos, read.ypos, marker=rm, fillstyle='full', mec='black', mfc='black', markersize=ms, zorder=3, alpha=ma)

            for i in range(len(read.starts)):
                readline_start = max(read.starts[i], elt_start) - elt_start
                readline_end   = min(read.ends[i], elt_end) - elt_start

                ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos, read.ypos], lw=lw, zorder=2, color=sample_color[bamname], alpha=la))

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax1.get_ybound())
                yheight = max(ax1.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    orig_positions = list(set(meth_table['orig_loc']))

    for i, x in enumerate(orig_positions):
        if i in (0, len(orig_positions)-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])

    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (elt_end-elt_start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    xt_labels = [str(int(t+elt_start)) for t in tick_list]
    xt_labels[0] = chrom

    ax0.set_xticks(tick_list)

    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)

    # llr plot

    logger.info('building llr plot...')
    ax4 = plt.subplot(gs[3])
    ax4.set_xticks([])
    ax4.yaxis.tick_right()

    ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)

    for mod in use_mods:
        upper = 1.0
        lower = 0.0

        if not methbam:
            upper, lower = get_cutoffs(list(data.values())[0][0], mod)

        ax4.axhline(y=upper, c='k', linestyle='--',lw=1)
        ax4.axhline(y=lower, c='k', linestyle='--',lw=1)

    ax4 = sns.lineplot(x='loc', y='modstat', hue='sample', data=meth_table, palette=sample_color, zorder=2)
    
    if args.statname:
        ax4.set_ylabel(args.statname)

    ax4.set_xlim(ax2.get_xlim())

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax5 = plt.subplot(gs[4])

    order_stack = 2

    smoothalpha = float(args.smoothalpha)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    for sample in sample_order:
        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        frac_masked = len(list(itertools.chain(*masked_segs))) / len(meth_n.values())

        logger.info('%s:%d-%d (%s), sample %s fraction masked: %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, frac_masked))

        if frac_masked > float(args.maxmaskedfrac):
            logger.warning('%s:%d-%d (%s), skip sample %s due to --maxmaskedfrac %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, float(args.maxmaskedfrac)))
            continue

        if frac_masked > 0.1 and args.bams is not None and args.ref is None:
            logger.warning('*** WARNING: specifying a reference genome (indexed via samtools faidx) with --restrict_ref is strongly recommended when using mod .bams ***')

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=4, color=sample_color[sample], alpha=smoothalpha)

        order_stack += 1

        if not args.nomask:
            for seg in masked_segs:
                if len(seg) > 2:
                    mf_seg = np.asarray(smoothed_methfrac)[seg]
                    pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]
                
                    ax5.plot(pos_seg, mf_seg, marker='', lw=4, color='#ffffff', alpha=0.8, zorder=order_stack)

                    order_stack += 1

    ax4.set_zorder(order_stack)  # adjust z-level of legend after smoothed plot is finished

    if args.color_by_hp:
        handles, labels = ax4.get_legend_handles_labels()
        phases = list(set([label.split('.')[-2] for label in labels]))

        found_phases = {}

        new_labels = []
        new_handles = []

        if args.phase_labels:
            if ':' not in args.phase_labels:
                sys.exit('incorrect syntax: %s' % args.phase_labels)

            phase_labels = dict([pl.split(':') for pl in args.phase_labels.split(',')])

        for i, label in enumerate(labels):
            hp = label.split('.')[-2]
            if hp not in found_phases:
                phase_name = hp

                if args.phase_labels:
                    if hp in phase_labels:
                        phase_name = phase_labels[hp]

                new_labels.append(phase_name)
                new_handles.append(handles[i])
                
            found_phases[hp] = True

        ax4.legend(new_handles, new_labels)

    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((float(args.ymin),float(args.ymax)))

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax5.get_ybound())
                yheight = max(ax5.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax5.add_patch(h)

    # coverage plot (optional)

    if args.plot_coverage:
        logger.info('building coverage plot...')
        ax6 = plt.subplot(gs[5])

        for sample in cover_bams:
            ax6.plot(list(windowed_methfrac.keys()), cover_bams[sample], marker='', lw=4, color=cover_color[sample], zorder=3, label=sample, alpha=smoothalpha)

        ax6.legend()
        ax6.set_xlim(ax2.get_xlim())

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax6.get_ybound())
                    yheight = max(ax6.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax6.add_patch(h)

    fn_prefix = '.%s_%d_%d.%s' % (chrom, elt_start, elt_end, ''.join(use_mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    if args.phased:
        fn_prefix += '.phased'

    param_str = '.ms%d.smw%d' % (args.modspace, int(args.smoothwindowsize))

    fn_prefix += param_str

    if args.max_read_density:
        fn_prefix += '.mrd%.2f' % float(args.max_read_density)

    if int(args.mincalls) > 0:
        fn_prefix += '.mc%d' % int(args.mincalls)

    outfn = fn_prefix

    if args.svg:
        outfn += '.locus.meth.svg'
    else:
        outfn += '.locus.meth.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg:
            if outfn.split('.')[-1] != 'svg':
                logger.warning('warning: %s does not have extension .svg, appending')
                outfn += '.svg'

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)


def region(args):
    '''
    plotting function for windowed view of larger regions
    '''

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    start, end = pos.split('-')

    start = int(start.replace(',',''))
    end = int(end.replace(',',''))

    assert start < end

    if args.motif is not None:
        assert iupac(args.motif)

    if end-start < 500000:
        logger.warning('locus smaller than 0.5 Mbp, "methylartist locus" may yield better results')

    scale_width = 1.0

    if args.scale_fullwidth:
        scale_width = (end-start)/float(args.scale_fullwidth)
        logger.info('scale width to %.3f' % scale_width)

    ref = pysam.FastaFile(args.ref)
    motifs = iupac(args.motif)
    region_seq = ref.fetch(chrom, start, end).upper()
    
    motif_count = 0
    for motif in motifs:
        motif_count += region_seq.count(motif) + region_seq.count(rc(motif))

    if args.windows is None:
        args.windows = round(motif_count / 30)
        logger.info('set window count to %d' % args.windows)

    args.windows = int(args.windows)

    if args.windows < 500:
        args.windows = 500
        logger.info('resetting windows to a minimum of 500')

    motifs_per_window = int(motif_count / int(args.windows))

    if motifs_per_window == 0:
        motifs_per_window = 1

    logger.info('motif count: %d, per window: %d' % (motif_count, motifs_per_window))

    w_starts = [start]
    w_ends = []

    i = 0
    motif_count = 0

    while i < len(region_seq) - len(motif):
        site_fwd = region_seq[i:i+len(motif)]
        site_rev = rc(site_fwd)

        for motif in motifs:
            if motif in (site_fwd, site_rev):
                motif_count += 1

        i += 1

        if motif_count == motifs_per_window:
            w_starts.append(start+i)
            w_ends.append(start+i)
            motif_count = 0

    w_ends.append(end)

    assert len(w_starts) == len(w_ends)
    logger.info('using %d windows normalised for %s content' % (len(w_starts), args.motif))

    if args.smoothwindowsize is None:
        w = len(w_starts)
        args.smoothwindowsize = round(0.02*w + 4)
        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            
        logger.info('set --smoothwindowsize to %d' % args.smoothwindowsize)

    args.smoothwindowsize = int(args.smoothwindowsize)

    if args.modspace is None:
        args.modspace = round(len(w_starts)/300)
        if args.modspace == 0:
            args.modspace = 1

        logger.info('set --modspace to %d' % args.modspace)

    args.modspace = int(args.modspace)

    data = dd(list)
    mods = []
    user_colours = {}

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)
                    mods += sorted(get_modnames(m_db))

                if len(c) == 3:
                    user_colours[bam] = c[2]

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')
        
        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    c = line.strip().split()
                    if len(c) == 1:
                        bams.append(c[0])
                    elif len(c) == 2:
                        bams.append(c[0])
                        user_colours[bams[-1]] = c[1]
                    else:
                        sys.exit(f'unparsable line in {args.bams}: {line.strip()}')

        for bam in bams:
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        mods = mods_methbam(bam)

    mods = list(set(mods))

    logger.info('found mods: %s' % ','.join(mods))

    if args.mods:
        for mod in args.mods.split(','):
            assert mod in mods, 'mod %s not found' % mod
        mods = args.mods.split(',')

    logger.info('using mods %s' % ','.join(mods))

    reads = {}
    orig_bam = {}

    phases = [None]
    if args.phased:
        phases = ['1','2']

    for phase in phases:
        for bam, meth_dbs in data.items():
            for mod in mods:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod

                if args.phased:
                    bamname += '.' + phase

                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, phase=phase, methbam=methbam, restrict_motif=args.motif, restrict_ref=args.ref)

    pool = mp.Pool(processes=int(args.procs))

    meth_segs = dd(dict)
    sample_names = {}
    shallow_windows = dd(list)
    min_window_calls = int(args.min_window_calls)

    for phase in phases:
        results = []

        for seg_start, seg_end in zip(w_starts, w_ends):
            for bam_fn, meth_dbs in data.items():
                seg_strand = '.'
                seg_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
                res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mods, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam])
                results.append(res)

        logger.info('parsing segments (phase %s)...' % str(phase))

        for res in results:
            meth_result, seg = res.get()
            if meth_result is None:
                continue

            seg_id = '%s:%d-%d' % seg[:3]

            seg_chrom, seg_start, seg_end, seg_name, seg_strand = map(str, seg[:-1])

            meth_segs[seg_id]['seg_id']     = seg_id
            meth_segs[seg_id]['seg_chrom']  = seg_chrom
            meth_segs[seg_id]['seg_start']  = seg_start
            meth_segs[seg_id]['seg_end']    = seg_end
            meth_segs[seg_id]['seg_name']   = seg_name
            meth_segs[seg_id]['seg_strand'] = seg_strand

            for modname, meth_data in meth_result.items():
                no_calls = 0
                meth_calls = 0
                unmeth_calls = 0

                if -1 in meth_data:
                    unmeth_calls = meth_data[-1]

                if 0 in meth_data:
                    no_calls = meth_data[0]

                if 1 in meth_data:
                    meth_calls = meth_data[1]

                if meth_calls + unmeth_calls < min_window_calls:
                    if seg_id not in shallow_windows[modname]:
                        shallow_windows[modname].append(seg_id) # TODO phase aware

                sample_name = seg_name + '.' + modname
                
                if args.phased:
                    sample_name += '.' + phase

                sample_names[sample_name] = True

                meth_segs[seg_id][sample_name + '_meth_calls'] = meth_calls
                meth_segs[seg_id][sample_name + '_unmeth_calls'] = unmeth_calls

                if meth_calls + unmeth_calls == 0:
                    meth_segs[seg_id][sample_name + '_frac'] = 0 # changed from NaN
                else:
                    meth_segs[seg_id][sample_name + '_frac'] = meth_calls/float(meth_calls+unmeth_calls)

    for mod in mods: # TODO phase aware
        shallow_frac = len(shallow_windows[mod])/len(w_starts)*100.0
        if shallow_frac > 0.0:
            logger.info('%.2f percent of windows for mod %s had less than %d calls' % (shallow_frac, mod, min_window_calls))
            if shallow_frac > float(args.maxuncovered):
                sys.exit('greater than %.2f windows are uncovered, aborting.' % float(args.maxuncovered))

    deleted_segs = 0
    for mod in mods: # TODO phase aware
        for seg_id in shallow_windows[mod]:
            if seg_id in meth_segs:
                del meth_segs[seg_id]
                deleted_segs += 1

    if deleted_segs > 0:
        logger.info('removed %d segs with less than %d calls in at least one mod' % (deleted_segs, min_window_calls))

    meth_segs = pd.DataFrame.from_dict(meth_segs).T

    if 'seg_start' not in meth_segs:
        logger.warning('no methylation calls.')
        sys.exit()

    meth_segs['seg_start'] = pd.to_numeric(meth_segs['seg_start'])

    meth_segs.sort_values('seg_start', inplace=True)
    meth_segs['pos'] = np.arange(len(meth_segs.index))

    for sample in sample_names:
        meth_segs[sample] = smooth(np.asarray(meth_segs[sample + '_frac']), window_len=int(args.smoothwindowsize), window=args.smoothfunc)
        meth_segs[sample] = meth_segs[sample].rolling(window=10, min_periods=1, center=True).mean()

    coord_to_pos = {}
    for orig_loc, new_loc in zip(meth_segs['seg_start'], meth_segs['pos']):
        coord_to_pos[orig_loc] = new_loc

    # coverage bams (if present)

    cover_bams = {}
    if args.plot_coverage:
        cov_bams = []

        if args.plot_coverage.endswith('.bam'):
            cov_bams = args.plot_coverage.split(',')

        if args.plot_coverage.endswith('.txt'):
            with open(args.plot_coverage) as _:
                for fn in _:
                    cov_bams.append(fn.strip())

        for bam_fn in cov_bams:
            logger.info('gather coverage from %s...' % bam_fn)
            meth_seg_starts = [int(s.split(':')[1].split('-')[0]) for s in meth_segs.index]
            meth_seg_ends = [int(s.split(':')[1].split('-')[1]) for s in meth_segs.index]
            cover = bam_bincover(bam_fn, chrom, meth_seg_starts, meth_seg_ends, procs=int(args.procs), log=args.logcover)
            c_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
            cover_bams[c_name] = smooth(np.asarray(cover), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

    # highlights

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()

                if c[0] != chrom:
                    continue

                h_s, h_e = map(int, c[1:3])

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s)
                h_end.append(h_e)

                h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)

        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()
        
        logger.info('found %d highlights for %s' % (len(h_start), chrom))

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    if end-start > 500000:
        logger.info('region size %s greater than 0.5 Mbp, setting --skip_align_plot True' % (end-start))
        args.skip_align_plot = True
    
    if args.force_align_plot == args.skip_align_plot == True:
        logger.info('--skip_align_plot overridden via --force_align_plot')
        args.skip_align_plot = False

    fig = plt.figure()
    height_ratios = [1,5,1,3]

    if args.skip_align_plot:
        height_ratios = [1,0,1,4]

    if args.plot_coverage:
        height_ratios.append(3)

    gs = gridspec.GridSpec(len(height_ratios),1,height_ratios=height_ratios, hspace=0)

    img_w = 16
    img_h = 8

    if args.skip_align_plot:
        if '--height' not in sys.argv[0]:
            args.height = 4.5

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    if args.scale_fullwidth:
        img_w = img_w*scale_width

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        if args.plot_coverage and len(p_ratios) < 5:
            logger.warning('--panelratios must have 5 values if used with --plot_coverage')
            sys.exit(1)

        gs = gridspec.GridSpec(len(height_ratios),1,height_ratios=p_ratios, hspace=0)

    sample_order = sorted(sample_names.keys())

    logger.info('sample order: %s' % ','.join(sample_order))

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sns.color_palette(args.samplepalette, n_colors=len(sample_order))[i]

    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_color = {}
            for hp in ('1','2'):
                hp_color[hp] = sns.color_palette(args.samplepalette, n_colors=2)[int(hp)-1]
            
            for sample in sample_color:
                hp = sample.split('.')[-1]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phase')

    cover_color = {}

    if args.plot_coverage:
        for i, c_name in enumerate(cover_bams.keys()):
            cover_color[c_name] = sns.color_palette(args.coverpalette, n_colors=len(cover_bams))[i]

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, start, end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []
    gene_specific_colours = {}

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

        if '.' in args.genes and os.path.exists(args.genes):
            logger.info('assuming --genes is a file, reading gene list from %s' % args.genes)

            genes_of_interest = []
            with open(args.genes) as genes_file:
                for line in genes_file:
                    c = line.strip().split()
                    genes_of_interest.append(c[0])
                    if len(c) > 1:
                        gene_specific_colours[c[0]] = c[1]

    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    if args.gtf:
        logger.info('%d genes in region' % len(genes))

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            continue

        y = 1

        while genemap[y].find(genes[ensg].tx_start, genes[ensg].tx_end):
            y += 1
            if args.gene_track_height:
                if y > int(args.gene_track_height):
                    y = int(args.gene_track_height)
                    break
        
        if genes[ensg].name in gene_specific_colours:
            gene_colours[i] = gene_specific_colours[genes[ensg].name]

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start, genes[ensg].tx_end], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start, genes[ensg].tx_end))

        if args.labelgenes:
            lg_x  = max(genes[ensg].tx_start, start)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(end-start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                if args.highlight_centerline:
                    clw = float(args.highlight_centerline)
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                else:
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)
        
    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    if args.skip_align_plot:
        logger.info('skipped alignment plot due to --skip_align_plot')

    else:
        logger.info('building read alignment plot...')

        readstack = dd(list)

        max_y  = 1
        pack_y = 1

        masked_count = 0

        for bamname in reads:
            fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])
            pos_cache = {}

            for read in fetch_reads_bam.fetch(chrom, start, end):
                if read.mapq < int(args.min_mapq):
                    continue

                if read.is_supplementary or read.is_secondary or read.is_duplicate:
                    if not args.allreads:
                        continue

                masked = False
                if len(readmask) > 0:
                    for mask_start, mask_end in readmask:
                        if read.reference_start >= mask_start and read.reference_end <= mask_end:
                            logger.debug('masked read: %s' % read.query_name)
                            masked_count += 1
                            masked = True
                
                if masked:
                    continue

                if read.query_name not in pos_cache:
                    pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
                else:
                    pos_cache[read.query_name][0].append(read.reference_start)
                    pos_cache[read.query_name][1].append(read.reference_end)

            for readname, read in reads[bamname].items():
                if readname not in pos_cache:
                    logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                    continue

                if read.call_count == 0:
                    continue

                read.ypos = max_y
                max_y += 1

                read.starts, read.ends = pos_cache[readname]
                readstack[bamname].append(read)

            fetch_reads_bam.close()

        if args.readmask:
            logger.info('masked %d reads due to --readmask' % masked_count)

        # read packing

        for bamname in readstack:
            reads = readstack[bamname]

            y = dd(list)

            for read in readstack[bamname]:
                y[read.ypos].append(read)

            for p in y:
                for q in y:
                    if p == q:
                        continue

                    for read_q in y[q]:
                            move = True

                            for read_p in y[p]:
                                if read_q.overlap(read_p):
                                    move = False

                            if move:
                                y[p].append(read_q)
                                y[q].remove(read_q)

            for p in y:
                if len(y[p]) > 0:
                    for read in sorted(y[p], key=lambda r: min(r.starts)):
                        read.ypos = pack_y
                    pack_y += 1

        ax1.set_ylim(0,pack_y+1)

        for bamname in readstack:
            for read in readstack[bamname]:
                for call_pos, call in read.meth_calls.items():
                    call_pos += start

                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=3)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=3)

                for i in range(len(read.starts)):
                    readline_start = max(read.starts[i], start)
                    readline_end   = min(read.ends[i], end)

                    ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos, read.ypos], zorder=2, color=sample_color[bamname], alpha=0.4))

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_start):
                    h_e = h_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax1.get_ybound())
                    yheight = max(ax1.get_ybound())-ymin
                    if args.highlight_centerline:
                        clw = float(args.highlight_centerline)
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                        pass
                    else:
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    for i, x in enumerate(meth_segs['seg_start']):
        if i in (0, len(meth_segs['seg_start'])-1):
            x2.append(x)
            x1.append(coord_to_pos[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_pos[x])

    
    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = None
            cpg_highlight_box = None

            if args.highlight_centerline:
                clw = float(args.highlight_centerline)
                orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), 1.0, 1.0, lw=clw, edgecolor=h_colors[i], facecolor=h_colors[i], zorder=2)
                cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), 1.0, 1.0, lw=clw, edgecolor=h_colors[i], facecolor=h_colors[i], zorder=3)
            else:
                orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
                cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)
    
    if args.highlight_centerline:
        clw = float(args.highlight_centerline)

        i = 0
        for x1_x, x2_x in zip(h_cpg_start, h_start):
            link_end1 = (x1_x, 1)
            link_end2 = (x2_x, 9)

            con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=h_colors[i], lw=clw)
            ax3.add_artist(con)
            i += 1

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (end-start)/n_ticks
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), int(tick_interval)))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    xt_labels = [str(int(t)) for t in tick_list]
    xt_labels[0] = chrom

    ax0.set_xticks(tick_list)
    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax4 = plt.subplot(gs[3])

    smoothalpha = float(args.smoothalpha)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    for sample in sample_order:
        ax4.plot(meth_segs['pos'], meth_segs[sample], marker='', lw=4, color=sample_color[sample], zorder=2, label=sample, alpha=smoothalpha)

    ax4.legend()
    ax4.set_xlim(ax2.get_xlim())
    ax4.set_ylim((float(args.ymin),float(args.ymax)))

    highlight_patches = []

    if args.highlight_alpha:
        a = float(args.highlight_alpha)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                if args.highlight_centerline:
                    clw = float(args.highlight_centerline)
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                else:
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    # coverage plot (optional)

    if args.plot_coverage:
        logger.info('building coverage plot...')
        ax5 = plt.subplot(gs[4])

        for sample in cover_bams:
            ax5.plot(meth_segs['pos'], cover_bams[sample], marker='', lw=4, color=cover_color[sample], zorder=3, label=sample, alpha=smoothalpha)

        ax5.legend()
        ax5.set_xlim(ax2.get_xlim())
        ax5.set_ylim((float(args.cover_ymin), max(ax5.get_ybound())))

        highlight_patches = []

        if args.highlight_alpha:
            a = float(args.highlight_alpha)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_cpg_start):
                    h_e = h_cpg_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax5.get_ybound())
                    yheight = max(ax5.get_ybound())-ymin
                    if args.highlight_centerline:
                        clw = float(args.highlight_centerline)
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], 1, yheight, lw=clw, edgecolor=h_color, facecolor=h_color, alpha=a, zorder=1))
                    else:
                        highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax5.add_patch(h)
            
    # output

    fn_prefix = '.%s_%d_%d.%s' % (chrom, start, end, ''.join(mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    fn_prefix += '.s%d.w%d.m%d' % (args.smoothwindowsize, args.windows, args.modspace)

    if args.phased:
        fn_prefix += '.phased'

    if args.scale_fullwidth:
        fn_prefix += '.scale_width.%.3f' % scale_width

    outfn = fn_prefix

    if args.svg:
        outfn += '.region.meth.svg'
    else:
        outfn += '.region.meth.png'

    if args.outfile is not None:
        outfn = args.outfile
        if args.svg:
            if outfn.split('.')[-1] != 'svg':
                logger.warning('warning: %s does not have extension .svg, appending')
                outfn += '.svg'

    fig.savefig(outfn, bbox_inches='tight')
    logger.info('plot saved to %s' % outfn)


def composite(args):
    '''
    plot composite methylation profiles relative to a consensus element

    Inputs are taken from the parsed command line (args):
      - args.teref / args.ref: consensus element fasta and indexed reference fasta
      - args.data or args.bams: methylation source (.bam + .db table, or mod .bams)
      - args.segdata: BED-like file of element locations with strand in column 4 or 5
      - args.mod / args.motif: modification and motif to plot
      - plot options (start, end, palette, blocks, minelts, maxelts, linewidth,
        alpha, ymin, ymax, meanplot_cutoff, meanplot_ylabel, svg, outfile)

    Per-element profiles are computed in a worker pool via
    get_meth_profile_composite, then drawn as three stacked panels: the
    per-site mean (+/- sd), a track showing motif positions on the consensus,
    and a random sample of individual element profiles.

    Writes the figure to args.outfile (or a name derived from the inputs) and,
    with args.output_table, a per-site .tsv table. Exits via sys.exit on
    invalid input or when no usable profiles are found.
    '''

    if not skbio_installed:
        sys.exit('scikit-bio is not installed but is required for this function. Please install e.g. via "pip install scikit-bio" or "conda install -c https://conda.anaconda.org/biocore scikit-bio"')

    # consensus sequence is upper-cased so the motif scan below is case-insensitive
    te_ref_seq = single_seq_fa(args.teref).upper()

    assert os.path.exists(args.ref + '.fai'), 'ref fasta must be indexed'

    mod_names = []

    # bam filename -> methylation .db (None when mods are read from the .bam)
    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    if args.color_by_phase and not args.phased:
        sys.exit('must specify --phased to use --color_by_phase')

    if args.motif is not None:
        assert iupac(args.motif)

    methbam = False
    user_colours = {}

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as data_file:
            for line in data_file:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth = c[:2]
                if ',' in meth:
                    sys.exit('multiple .db files per bam not supported for composite')

                data[bam] = meth

                mod_names.extend(get_modnames(meth))

                # optional third column: user-specified colour for this sample
                if len(c) == 3:
                    bam_noext = '.'.join(os.path.basename(bam).split('.')[:-1])
                    user_colours[bam_noext] = c[2]

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        elif is_bam(args.bams):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    fields = line.strip().split()
                    if fields:  # tolerate blank lines in the list file
                        bams.append(fields[0])

        if not bams:
            sys.exit('no .bam files found in %s' % args.bams)

        for bam in bams:
            # entries may carry a ':'-separated suffix, only the path is used here
            if ':' in bam:
                bam, _suffix = bam.split(':')

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # mod names are taken from the last .bam in the list
        mod_names = mods_methbam(bam)

    mod_names = list(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    use_mod = None

    if len(mod_names) == 1:
        use_mod = mod_names[0]
        logger.info('using the one available mod: %s' % use_mod)

    if use_mod is None:
        if args.mod is None:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))

        use_mod = args.mod

        if use_mod not in mod_names:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))

    # derive default output names from the input file names

    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'

    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])

    seg_basename = '.'.join(os.path.basename(args.segdata).split('.')[:-1])

    outfn = data_basename + '.' + seg_basename

    if args.meanplot_cutoff:
        outfn += '.meanplot_cutoff_%d' % int(args.meanplot_cutoff)

    if args.phased:
        outfn += '.phased'

    outfn += '.composite'

    table_fn = outfn + '.table.tsv'

    if args.svg:
        outfn += '.svg'
    else:
        outfn += '.png'

    # one job per element (and per phase when --phased)
    phases = ('1', '2') if args.phased else (None,)

    out_res = dd(list) # cache for --outelts

    # context manager guarantees worker processes are cleaned up, including on sys.exit
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.segdata) as bed:
            for line in bed:
                c = line.strip().split()

                if not c:  # skip blank lines
                    continue

                if len(c) > 3 and c[3] in ('-', '+'):
                    seg_strand = c[3]

                elif len(c) > 4 and c[4] in ('-', '+'):
                    seg_strand = c[4]

                else:
                    sys.exit('strand (+/-) not found in cols 4 or 5 of %s' % args.segdata)

                seg_chrom = c[0]
                seg_start = int(c[1])
                seg_end   = int(c[2])

                for phase in phases:
                    res = pool.apply_async(get_meth_profile_composite, [args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase])
                    results.append(res)

        # collect mod data
        for res in results:
            per_bam_res = res.get()

            for bam in per_bam_res:
                if per_bam_res[bam] is None:
                    continue

                coord_meth_pos, meth_profile, elt_info = per_bam_res[bam]

                if len(coord_meth_pos) == 0:
                    continue

                out_res[bam].append((coord_meth_pos, meth_profile, elt_info))

    # set bounds
    mod_start = 0
    mod_end = len(te_ref_seq)

    if args.start:
        mod_start = int(args.start)

    if args.end:
        mod_end = int(args.end)

    assert mod_start < mod_end

    # set up plot

    sample_color = {}
    if args.color_by_phase:
        phase_palette = sns.color_palette(args.palette, n_colors=2)
        for i, phase in enumerate(('phase1', 'phase2')):
            sample_color[phase] = phase_palette[i]

    else:
        bam_palette = sns.color_palette(args.palette, n_colors=len(out_res)) if out_res else []
        for i, bam in enumerate(out_res):
            sample_color[bam] = bam_palette[i]

            if bam in user_colours:
                sample_color[bam] = user_colours[bam]

    fig = plt.figure()
    gs = gridspec.GridSpec(3,1,height_ratios=[3,1,8])

    # mean

    ax0 = plt.subplot(gs[0])
    ax0.set_ylim((-0.05,1.05))
    ax0.set_xlim((mod_start, mod_end))
    ax0.set_xticks([])

    meanplot_rows = []

    for bam in out_res:
        # consensus coordinate -> list of (methylation value, element id)
        meth_by_coord = dd(list)

        per_elt_calls = []

        for coord_meth_pos, meth_profile, elt_info in out_res[bam]:
            per_elt_calls.append(len(coord_meth_pos))
            for c, m in zip(coord_meth_pos, meth_profile):
                meth_by_coord[c].append((m, elt_info))

        median_call_count = int(np.median(per_elt_calls))

        # a site is included in the mean plot only if at least `cutoff` elements cover it
        cutoff = median_call_count

        if args.meanplot_cutoff:
            cutoff = int(args.meanplot_cutoff)

        logger.info('%s: median per element call count: %d' % (bam, median_call_count))
        logger.info('%s: per site call count cutoff: %d' % (bam, cutoff))

        sample_label = bam.split('.')[-1] if args.color_by_phase else bam

        for c in sorted(meth_by_coord.keys()):
            site_calls = meth_by_coord[c]

            if len(site_calls) < cutoff:
                continue

            for m, elt_info in site_calls:
                # split from the right so contig names containing '_' stay intact
                elt_chrom, elt_start, elt_end, elt_strand = elt_info.rsplit('_', 3)

                meanplot_rows.append({
                    'chrom': elt_chrom,
                    'start': int(elt_start),
                    'end': int(elt_end),
                    'strand': elt_strand,
                    'sample': sample_label,
                    'coord': c,
                    'meth': m,
                })

    if len(meanplot_rows) == 0:
        sys.exit('no successful alignments!')

    meanplot_table = pd.DataFrame(meanplot_rows)
    meanplot_table['start'] = np.array(meanplot_table['start'], dtype=int)
    meanplot_table['end'] = np.array(meanplot_table['end'], dtype=int)
    meanplot_table['coord'] = np.array(meanplot_table['coord'], dtype=int)
    meanplot_table['meth'] = np.array(meanplot_table['meth'], dtype=float)

    if args.output_table:
        # QUOTE_MINIMAL == 0, the value previously passed here
        meanplot_table.to_csv(table_fn, sep='\t', quoting=csv.QUOTE_MINIMAL, index=False)
        logger.info('wrote per-site table to %s' % table_fn)

    # drawn on the current axes (ax0, created above)
    # NOTE(review): seaborn >= 0.12 deprecates ci= in favour of errorbar= — confirm supported versions before changing
    ax0 = sns.lineplot(x='coord', y='meth', data=meanplot_table, ci='sd', lw=2, hue='sample', palette=sample_color)
    ax0.set_ylabel(args.meanplot_ylabel)

    # mod
    ax1 = plt.subplot(gs[1])
    ax1.set_xlim((mod_start, mod_end))

    # background box spans the plotted interval (anchored at mod_start, not 0)
    box = matplotlib.patches.Rectangle([mod_start, 0], mod_end-mod_start, 1.0, edgecolor='#555555', facecolor='#cfcfcf', zorder=1)
    ax1.add_patch(box)

    if args.blocks:
        with open(args.blocks) as blocks:
            for line in blocks:
                if not line.strip():  # skip blank lines
                    continue

                b_start, b_end, b_name, b_col = line.strip().split()
                b_start = int(b_start)
                b_end = int(b_end)

                box = matplotlib.patches.Rectangle([b_start, 0], b_end-b_start, 1.0, edgecolor='#555555', facecolor=b_col, zorder=2)
                ax1.add_patch(box)

    # positions of exact motif matches on the consensus within the plotted bounds
    mod_locs = []

    if args.motif is not None:
        motif = args.motif.upper()
        motif_len = len(motif)

        # +1 so that the final k-mer of the consensus is also tested
        for i in range(len(te_ref_seq)-motif_len+1):
            if mod_start <= i <= mod_end and te_ref_seq[i:i+motif_len] == motif:
                mod_locs.append(i)

    ax1.vlines(mod_locs, 0, 1, lw=1, colors=('#FF4500'), zorder=3, alpha=0.5)

    ax1.spines['bottom'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim((mod_start, mod_end))
    ax1.xaxis.set_ticks_position('top')

    # wiggles
    ax2 = plt.subplot(gs[2])

    # upper bound for random z-order so samples are interleaved rather than layered
    z_max = len(out_res)*int(args.maxelts)

    for bam in out_res:
        n_elts = len(out_res[bam])

        if n_elts < int(args.minelts):
            sys.exit('fewer than --minelts (%d) usable elements (%d), giving up.' % (int(args.minelts), n_elts))

        sample_size = min(int(args.maxelts), n_elts)

        logger.info('sample %s has %d useable elements, will sample %d' % (bam, n_elts, sample_size))

        # with --color_by_phase colours are keyed by the last '.'-separated part of the sample name
        colour_key = bam.split('.')[-1] if args.color_by_phase else bam

        for coord_meth_pos, meth_profile, _elt_info in random.sample(out_res[bam], sample_size):
            ax2.plot(coord_meth_pos, meth_profile, lw=float(args.linewidth), alpha=float(args.alpha), color=sample_color[colour_key], zorder=random.randint(0,z_max))

    ax2.set_ylim((float(args.ymin),float(args.ymax)))
    ax2.set_xlim((mod_start, mod_end))
    ax2.set_xlabel('position')

    fig.set_size_inches(16, 6)

    if args.outfile is not None:
        outfn = args.outfile

    plt.savefig(outfn, bbox_inches='tight')
    logger.info('plotted to %s' % outfn)


def _wgmeth_outfn(args, phase=None):
    '''
    builds the output filename for one wgmeth table

    args  : parsed wgmeth arguments (uses .outfile, .bam, .chrom, .mod, .dss)
    phase : None for unphased output, otherwise the phase index (0 or 1)

    Without -o/--outfile the name is derived from the .bam basename. With
    -o/--outfile and phased output the phase is inserted before the extension
    so the two phases are written to separate files rather than the second
    overwriting the first.
    '''
    if args.outfile is not None:
        if phase is None:
            return args.outfile

        root, ext = os.path.splitext(args.outfile)
        return '%s.phase_%d%s' % (root, phase, ext)

    outfn = '.'.join(os.path.basename(args.bam).split('.')[:-1])

    if args.chrom:
        outfn += '.%s' % args.chrom

    suffix = 'DSS.txt' if args.dss else 'methyl.bed'

    if phase is None:
        return outfn + '.%s.%s' % (str(args.mod), suffix)

    return outfn + '.%s.phase_%d.%s' % (str(args.mod), phase, suffix)


def _write_wgmeth_table(args, table, outfn):
    '''
    sorts one per-site table by chr/pos and writes it as DSS or bedMethyl

    args  : parsed wgmeth arguments (uses .dss, .methdb, .bam)
    table : DataFrame indexed by position with (at least) columns chr, N, X
    outfn : path of the file to write
    '''
    table['pos'] = table.index
    table = table.sort_values(['chr', 'pos'])

    logger.info('writing %s' % outfn)

    if args.dss:
        table.to_csv(outfn, columns=['chr','pos','N','X'], index=False, sep='\t') # 1-based
        return

    # --methdb is None when calls come from a mod .bam, fall back to the .bam name
    if args.methdb:
        name = '.'.join(args.methdb.split('.')[:-1])
    else:
        name = '.'.join(args.bam.split('.')[:-1])

    table['score']  = table['N'].where(table['N'] <= 1000, 1000) # cap score at 1000
    table['start']  = table['pos']-1 # 0-based
    table['end']    = table['pos']
    table['pct']    = table['X']/table['N']*100 # percentage for both phased and unphased output
    table['name']   = name
    table['strand'] = '.'
    table['colour'] = '0,0,0'

    table.to_csv(outfn, columns=['chr','start','end','name', 'score', 'strand', 'start', 'end', 'colour', 'N', 'pct'], header=False, index=False, sep='\t')


def wgmeth(args):
    '''
    generates whole genome output in DSS or bedmethyl format

    The genome (or just --chrom) is split into --binsize segments taken from
    the .fai index, each segment is handed to get_meth_calls_wg in a worker
    pool, and the per-segment results are concatenated and written to one
    file (unphased) or one file per phase (--phased).

    Exits with a non-zero status if required arguments are missing, if the
    requested mod is unavailable, or if no calls are found.
    '''

    bin_size = int(args.binsize)

    # without a methylation db the calls are read from MM/ML tags in the .bam
    methbam = args.methdb is None

    if methbam and None in (args.ref, args.motif):
        logger.warning('--ref and --motif are required when using mod .bams (no --methdb)')
        sys.exit(1)

    motifsize = None

    if args.motif is not None:
        # explicit check rather than assert, which is stripped under python -O
        if not iupac(args.motif):
            sys.exit('--motif %s failed IUPAC validation' % args.motif)

        motifsize = len(args.motif)
        logger.info('motif size %d (%s)' % (motifsize, args.motif))

    if methbam:
        mods = sorted(mods_methbam(args.bam))

        if len(mods) == 0:
            sys.exit('bam %s does not appear to contain MM/ML tags' % args.bam)

    else:
        mods = sorted(get_modnames(args.methdb))

    if args.mod not in mods:
        if args.mod is None:
            logger.warning('must specify which to use with --mod, available mods: %s' % ','.join(mods))
        else:
            logger.warning('mod %s not in known mods for db: %s' % (args.mod, ','.join(mods)))
        sys.exit(1)

    phases = (0, 1) if args.phased else (0,)
    seg_meth_table_store = dd(list)

    # pool is created only after validation and is always cleaned up on exit
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.fai) as fai:
            for line in fai:
                fields = line.strip().split()

                # tolerate blank / truncated lines in the index
                if len(fields) < 2:
                    continue

                chrom = fields[0]
                chrlen = int(fields[1])

                if args.chrom and args.chrom != chrom:
                    continue

                for seg_start in range(0, chrlen, bin_size):
                    seg_end = min(seg_start + bin_size, chrlen)

                    res = pool.apply_async(get_meth_calls_wg, [args, args.bam, args.methdb, chrom, seg_start, seg_end, args.phased, args.mod, motifsize])

                    results.append(res)

        # results must be collected before leaving the with block (exit terminates the pool)
        for res in results:
            seg_meth_table = res.get()

            if seg_meth_table is None:
                continue

            for phase in phases:
                seg_meth_table_store[phase].append(pd.DataFrame.from_dict(seg_meth_table[phase]).T)

    for phase in phases:
        seg_tables = seg_meth_table_store[phase]

        # pd.concat raises ValueError on an empty list, so treat that as "no calls"
        if seg_tables:
            meth_table = pd.concat(seg_tables)
        else:
            meth_table = pd.DataFrame()

        if len(meth_table) == 0:
            if args.phased:
                sys.exit('no calls for phase %d: is this data phased?' % phase)

            sys.exit('no calls found for mod %s' % str(args.mod))

        outfn = _wgmeth_outfn(args, phase if args.phased else None)

        _write_wgmeth_table(args, meth_table, outfn)

def main(args):
    '''
    logs the invoking command line, then runs the handler stored in args.func
    (set per subcommand via set_defaults) with the parsed arguments
    '''
    command_line = ' '.join(sys.argv)
    logger.info('starting methylartist with command: %s' % command_line)

    handler = args.func
    handler(args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='methylartist: tools for exploring nanopore modified base data')
    subparsers = parser.add_subparsers(title="tool", dest="tool")
    subparsers.required = True

    __version__ = "1.2.7"
    parser.add_argument('-v', '--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    parser_nanopolish    = subparsers.add_parser('db-nanopolish')
    parser_megalodon     = subparsers.add_parser('db-megalodon')
    parser_custom        = subparsers.add_parser('db-custom')
    parser_guppy         = subparsers.add_parser('db-guppy')
    parser_segmeth       = subparsers.add_parser('segmeth')
    parser_segplot       = subparsers.add_parser('segplot')
    parser_locus         = subparsers.add_parser('locus')
    parser_region        = subparsers.add_parser('region')
    parser_composite     = subparsers.add_parser('composite')
    parser_wgmeth        = subparsers.add_parser('wgmeth')
    parser_adjustcutoffs = subparsers.add_parser('adjustcutoffs')
    parser_scoredist     = subparsers.add_parser('scoredist')

    parser_segmeth.set_defaults(func=segmeth)
    parser_segplot.set_defaults(func=segplot)
    parser_locus.set_defaults(func=locus)
    parser_region.set_defaults(func=region)
    parser_composite.set_defaults(func=composite)
    parser_wgmeth.set_defaults(func=wgmeth)
    parser_nanopolish.set_defaults(func=db_nanopolish)
    parser_megalodon.set_defaults(func=db_megalodon)
    parser_custom.set_defaults(func=db_custom)
    parser_guppy.set_defaults(func=db_guppy)
    parser_adjustcutoffs.set_defaults(func=adjustcutoffs)
    parser_scoredist.set_defaults(func=scoredist)

    # options for methylation segments
    parser_segmeth.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_segmeth.add_argument('-b', '--bams', default=None, help='one or more .bams with Mm and Ml tags for modification calls (see samtags spec)')
    parser_segmeth.add_argument('-i', '--intervals', required=True, help='.bed file')
    parser_segmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_segmeth.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segmeth.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_segmeth.add_argument('--ref', default=None, help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_segmeth.add_argument('--motif', default=None, help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    parser_segmeth.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_segmeth.add_argument('--excl_ambig', action='store_true', default=False, help='do not consider reads that align entirely within segment')
    parser_segmeth.add_argument('--spanning_only', action='store_true', default=False, help='only consider reads that span segment')
    parser_segmeth.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_segmeth.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')

    # options for methylation strip / violin plots
    parser_segplot.add_argument('-s', '--segmeth', required=True, help='output from segmeth.py')
    parser_segplot.add_argument('-m', '--samples', default=None, help='samples, comma delimited')
    parser_segplot.add_argument('-d', '--mods', default=None, help='mods, comma delimited')
    parser_segplot.add_argument('-c', '--categories', default=None, help='categories, comma delimited, need to match seg_name column from input')
    parser_segplot.add_argument('-v', '--violin', default=False, action='store_true')
    parser_segplot.add_argument('-g', '--ridge', default=False, action='store_true')
    parser_segplot.add_argument('-n', '--mincalls', default=10, help='minimum number of calls to include site (methylated + unmethylated) (default=10)')
    parser_segplot.add_argument('-r', '--minreads', default=1, help='minimum reads in interval (default = 1)')
    parser_segplot.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segplot.add_argument('-a', '--group_by_annotation', default=False, action='store_true', help='group plots by annotation rather than by sample')
    parser_segplot.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_segplot.add_argument('--width', default=None, help='figure width (default = automatic)')
    parser_segplot.add_argument('--height', default=6, help='figure height (default = 6)')
    parser_segplot.add_argument('--pointsize', default=1, help='point size for scatterplot (default = 1)')
    parser_segplot.add_argument('--ymin', default=-0.15, help='ymin (default = -0.15)')
    parser_segplot.add_argument('--ymax', default=1.15, help='ymax (default = 1.15)')
    parser_segplot.add_argument('--ylabel', default='pct methylation', help='set label for y-axis (default: pct methylation)')
    parser_segplot.add_argument('--tiltlabel', default=False, action='store_true')
    parser_segplot.add_argument('--palette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_segplot.add_argument('--ridge_alpha', default=1.0, help='alpha (tranparency) for ridge plot fills (default = 1.0)')
    parser_segplot.add_argument('--ridge_spacing', default=-0.25, help='ridge plot spacing (generally negative, default = -0.25)')
    parser_segplot.add_argument('--ridge_smoothing', default=0.5, help='smoothing parameter for ridge plot, bigger is smoother (default=0.5)')
    parser_segplot.add_argument('--svg', default=False, action='store_true')

    # options for locus-specific plots
    parser_locus.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_locus.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_locus.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_locus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_locus.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_locus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_locus.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_locus.add_argument('-t', '--slidingwindowstep', default=1, help='step size for initial sliding window (default=1)')
    parser_locus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-seperated integers. Default: 1,5,1,3,3')
    parser_locus.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_locus.add_argument('-r', '--ref', default=None, help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_locus.add_argument('-n', '--motif', default=None, help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    parser_locus.add_argument('-c', '--plot_coverage', default=None, help='plot coverage from bam(s) (can be comma-delimited list)')
    parser_locus.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_locus.add_argument('--logcover', default=False, action='store_true', help='apply log2(count+1) to coverage data (--plot_coverage)')
    parser_locus.add_argument('--coverprocs', default=1, help='processes to use for coverage function (default=1)')
    parser_locus.add_argument('--bed', default=None, help='.bed file for additional annotations')
    parser_locus.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_locus.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_locus.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_locus.add_argument('--phased', default=False, action='store_true', help='split samples into phases')
    parser_locus.add_argument('--ignore_ps', default=False, action='store_true', help='do not use phase set (PS) when plotting phased data (HP only)')
    parser_locus.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_locus.add_argument('--phase_labels', default=None, help='if --color_by_hp substitute HP tags for labels. Format HP:Label comma-delimited e.g.: 1:Father,2:Mother')
    parser_locus.add_argument('--include_unphased', default=False, action='store_true', help='include an "unphased" category if called with --phased')
    parser_locus.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_locus.add_argument('--readmarker', default='o', help='marker for (un)methylated glpyhs in read panel (matplotlib markers, default=o)')
    parser_locus.add_argument('--markeralpha', default=1.0, help='alpha (transparency) for (un)methylation marker (default=1.0)')
    parser_locus.add_argument('--readmarkersize', default=2.0, help='marker size for (un)methylated glpyhs in read panel (default=2.0)')
    parser_locus.add_argument('--readlinewidth', default=1.0, help='width for lines representing read alignments (default=1.0)')
    parser_locus.add_argument('--readlinealpha', default=0.5, help='alpha (transparency) for read mapping lines (default=0.4)')
    parser_locus.add_argument('--readopenmarkeredgecolor', default=None, help='edge color for open (unmethylated) markers in read plot (default = sample color)')
    parser_locus.add_argument('--slidingwindowsize', default=2, help='size of initial sliding window for coverage check (default=2)')
    parser_locus.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_locus.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_locus.add_argument('--maskcutoff', default=1, help='read count masking cutoff (default=1)')
    parser_locus.add_argument('--maxmaskedfrac', default=1.0, help='skip smoothed plot if fraction of sample masked (--maskcutoff) > this value (default = 1.0)')
    parser_locus.add_argument('--nomask', default=False, action='store_true', help='skip drawing segment masks')
    parser_locus.add_argument('--mincalls', default=0, help='drop modspace positions if call count (meth+unmeth) < --mincalls (default=0)')
    parser_locus.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_locus.add_argument('--modspace', default=None, help='spacing between links in top panel (default=auto)')
    parser_locus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_locus.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_locus.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot (default = -0.05)')
    parser_locus.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot (default = 1.05)')
    parser_locus.add_argument('--cover_ymin', default=0, help='y-axis minimum for coverage plot (default = 0)')
    parser_locus.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_locus.add_argument('--statname', default=None, help='label for raw statistic plot')
    parser_locus.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_locus.add_argument('--coverpalette', default="mako", help='colour palette name for coverage plot (default = "mako")')
    parser_locus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_locus.add_argument('--genepalette', default="viridis", help='colour palette name for highlights (default = "viridis")')
    parser_locus.add_argument('--highlight_alpha', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_locus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_locus.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_locus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_locus.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_locus.add_argument('--height', default=8, help='image width (inches, default=8)')
    parser_locus.add_argument('--svg', action='store_true')

    # options for region plots
    parser_region.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_region.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_region.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_region.add_argument('-n', '--motif', required=True, help='normalise window sizes to motif occurance')
    parser_region.add_argument('-r', '--ref', required=True, help='ref genome fasta, required if normalising windows with -n/--norm_motif')
    parser_region.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_region.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_region.add_argument('-w', '--windows', default=None, help='set window count, default=auto')
    parser_region.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_region.add_argument('-m', '--mods', default=None, help='mods to consider (comma-delimited, default = all available)')
    parser_region.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_region.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_region.add_argument('-c', '--plot_coverage', default=None, help='plot coverage from bam(s) (can be comma-delimited list)')
    parser_region.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_region.add_argument('--logcover', default=False, action='store_true', help='apply log2(count+1) to coverage data (--plot_coverage)')
    parser_region.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_region.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_region.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_region.add_argument('--maxuncovered', default=50.0, help='maximum percentage of uncovered windows tolerated (default = 50.0)')
    parser_region.add_argument('--modspace', default=None, help='increase to increase spacing between links in top panel (default=auto)')
    parser_region.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_region.add_argument('--min_window_calls', default=1, help='minimum reads per window to include in plot (default = 1)')
    parser_region.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_region.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_region.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot (default = -0.05)')
    parser_region.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot (default = 1.05')
    parser_region.add_argument('--cover_ymin', default=0, help='y-axis minimum for coverage plot (default = 0)')
    parser_region.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_region.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_region.add_argument('--coverpalette', default="mako", help='colour palette name for coverage plot (default = "mako")')
    parser_region.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_region.add_argument('--genepalette', default="viridis", help='colour palette name for highlights (default = "viridis")')
    parser_region.add_argument('--gene_track_height', default=None, help='maximum number of gene track layers')
    parser_region.add_argument('--highlight_alpha', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_region.add_argument('--highlight_centerline', default=None, help='change highlight to line (specify width)')
    parser_region.add_argument('--panelratios',  default=None, help='Alter panel ratios: needs to be 4 (or 5 if --plot_coverage) comma-seperated integers. Default: 1,5,3,3')
    parser_region.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_region.add_argument('--skip_align_plot', default=False, action='store_true', help='blank alignment plot, useful if unneeded or for runtime.')
    parser_region.add_argument('--force_align_plot', default=False, action='store_true', help='retain alignment plot even over regions > 5Mbp where it would be disabled automatically')
    parser_region.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_region.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_region.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_region.add_argument('--height', default=8, help='image width (inches, default=8)')
    parser_region.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_region.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_region.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_region.add_argument('--scale_fullwidth', default=None, help='scale plot output relative to value (e.g. use length of chrom 1)')
    parser_region.add_argument('--svg', action='store_true', default=False)

    # options for composite plots
    parser_composite.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_composite.add_argument('-b', '--bams', default=None, help='one or more .bams with MM and ML tags for modification calls (see samtags spec)')
    parser_composite.add_argument('-s', '--segdata', required=True, help='BED3+1: chrom, start, end, strand')
    parser_composite.add_argument('-r', '--ref', required=True, help='ref genome fasta')
    parser_composite.add_argument('-t', '--teref', required=True, help='TE ref fasta')
    parser_composite.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_composite.add_argument('-c','--palette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_composite.add_argument('-a', '--alpha', default=0.3, help='alpha (default: 0.3)')
    parser_composite.add_argument('-w', '--linewidth', default=1, help='line width (default: 1)')
    parser_composite.add_argument('-l', '--lenfrac', default=0.95, help='fraction of TE length that must align (default 0.95)')
    parser_composite.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_composite.add_argument('-o', '--outfile', default=None, help='output file name (default: generated from input)')
    parser_composite.add_argument('--meanplot_ylabel', default='% methylation', help='set y-axis label on mean plot')
    parser_composite.add_argument('--meanplot_cutoff', default=None, help='override site coverage cutoff for mean plot (see output for automatic value)')
    parser_composite.add_argument('--mod', default=None, help='modification to plot (mod codes will be listed, default: infer from sample name')
    parser_composite.add_argument('--motif', default='CG', help='modified motif to highlight (default = CG)')
    parser_composite.add_argument('--blocks', default=None, help='blocks to highlight (txt file with start, end, name, hex colour)')
    parser_composite.add_argument('--start', default=None, help='start plotting at this base (default None)')
    parser_composite.add_argument('--end', default=None, help='end plotting at this base (default None)')
    parser_composite.add_argument('--mincalls', default=100, help='minimum call count to include elt (default = 100)')
    parser_composite.add_argument('--minelts', default=1, help='minimum output elements (default = 1)')
    parser_composite.add_argument('--maxelts', default=200, help='maximum output elements, if > max random.sample() (default = 200)')
    parser_composite.add_argument('--slidingwindowsize', default=10, help='size of sliding window for meth frac (default 10)')
    parser_composite.add_argument('--slidingwindowstep', default=1, help='step size for meth frac (default 1)')
    parser_composite.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default 8)')
    parser_composite.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_composite.add_argument('--ymin', default=-0.05, help='y-axis minimum for smoothed plot')
    parser_composite.add_argument('--ymax', default=1.05, help='y-axis maximum for smoothed plot')
    parser_composite.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_composite.add_argument('--excl_ambig', action='store_true', default=False)
    parser_composite.add_argument('--primary_only', action='store_true', default=False, help='ignore non-primary alignments')
    parser_composite.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_composite.add_argument('--color_by_phase', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_composite.add_argument('--output_table', default=False, action='store_true', help='output per-site data to table (.tsv)')
    parser_composite.add_argument('--svg', action='store_true', default=False)

    # options for whole genome output
    #
    # NOTE(review): none of these options set type=, so values supplied on
    # the command line arrive as str while the defaults keep their literal
    # Python types (int / None) -- presumably converted where they are used;
    # confirm before relying on the type.

    # required inputs
    parser_wgmeth.add_argument(
        '-b', '--bam', required=True,
        help='bam used for methylation calling')
    parser_wgmeth.add_argument(
        '-d', '--methdb', default=None,
        help='methylation database')
    parser_wgmeth.add_argument(
        '-s', '--binsize', default=1000000,
        help='bin size for parallelisation, default = 1000000')
    parser_wgmeth.add_argument(
        '-f', '--fai', required=True,
        help='fasta index (.fai)')

    # optional selection / tuning
    parser_wgmeth.add_argument(
        '-m', '--mod', default=None,
        help='output for specific mod (names vary, see output for hints)')
    parser_wgmeth.add_argument(
        '-p', '--procs', default=1,
        help='multiprocessing')
    parser_wgmeth.add_argument(
        '-c', '--chrom', default=None,
        help='limit analysis to one chromosome')
    parser_wgmeth.add_argument(
        '-q', '--min_mapq', default=10,
        help='minimum mapping quality (mapq), default = 10')
    parser_wgmeth.add_argument(
        '-r', '--ref', default=None,
        help='reference genome .fa (build .fai index with samtools faidx) (required for mod bams)')
    parser_wgmeth.add_argument(
        '-o', '--outfile', default=None,
        help='output file name (default: generated from input)')
    parser_wgmeth.add_argument(
        '--motif', default=None,
        help='expected modification motif (e.g. CG for 5mCpG required for mod bams)')
    # Help text previously read "call density greater >= value"; the stray
    # "greater" is dropped so the threshold description is unambiguous.
    parser_wgmeth.add_argument(
        '--max_read_density', default=None,
        help='filter reads with call density >= value, can be helpful in footprinting assays (default=None)')

    # boolean switches (False unless given)
    parser_wgmeth.add_argument(
        '--dss', action='store_true', default=False,
        help='output in DSS format (default = bedMethyl)')
    parser_wgmeth.add_argument(
        '--phased', default=False, action='store_true',
        help='split output into phases (currently just 1,2)')
    parser_wgmeth.add_argument(
        '--primary_only', default=False, action='store_true',
        help='ignore non-primary alignments')

    # options for custom db
    #
    # These describe the layout of a user-supplied per-read table: which
    # column holds which field, plus the thresholds used when calling bases.

    # input table and its format
    parser_custom.add_argument(
        '-m', '--methdata', required=True,
        help='per-read methylation output table')
    parser_custom.add_argument(
        '--header', action='store_true', default=False,
        help='input table has header')
    # Help text fixed: "delimimter" typo and an unbalanced parenthesis.
    parser_custom.add_argument(
        '--delimiter', default=None,
        help='column delimiter char (default = whitespace, i.e. tab or space)')

    # mandatory column mappings
    parser_custom.add_argument(
        '--readname', required=True,
        help='readname column number')
    parser_custom.add_argument(
        '--chrom', required=True,
        help='chromosome column number')
    parser_custom.add_argument(
        '--pos', required=True,
        help='genomic (i.e. on chromosome/contig) position column number, 0-based')
    parser_custom.add_argument(
        '--strand', required=True,
        help='strand column number')
    parser_custom.add_argument(
        '--modprob', required=True,
        help='column number for probability of modified base')

    # optional column mappings / mod naming
    parser_custom.add_argument(
        '--canprob', default=None,
        help='column number for probability of canonical base (if not given, assume p=1-modprob)')
    parser_custom.add_argument(
        '--modbasecol', default=None,
        help='column number for modified base/motif name (optional, can use --modbase instead)')
    parser_custom.add_argument(
        '--modbase', default=None,
        help='specify modified base/motif name (overrides --modbasecol)')

    # database target and calling thresholds
    parser_custom.add_argument(
        '-d', '--db', default=None,
        help='database name (default: auto-infer)')
    parser_custom.add_argument(
        '--minmodprob', default=0.8,
        help='probability threshold for calling modified base (default = 0.8)')
    parser_custom.add_argument(
        '--mincanprob', default=None,
        help='probability threshold for calling canonical base (default = minmodprob)')
    parser_custom.add_argument(
        '-a', '--append', action='store_true', default=False,
        help='append to database')
    parser_custom.add_argument(
        '--motifsize', default=2,
        help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for megalodon db
    # (input table, target database, calling threshold, append switch)
    parser_megalodon.add_argument(
        '-m', '--methdata', required=True,
        help='megalodon per_read_text methylation output')
    parser_megalodon.add_argument(
        '-d', '--db', default=None,
        help='database name (default: auto-infer)')
    parser_megalodon.add_argument(
        '-p', '--minprob', default=0.8,
        help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_megalodon.add_argument(
        '-a', '--append', action='store_true', default=False,
        help='append to database')
    parser_megalodon.add_argument(
        '--motifsize', default=2,
        help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for guppy db

    # required inputs: sample label, fast5, motif/mod naming, alignments, reference
    parser_guppy.add_argument(
        '-s', '--samplename', required=True,
        help='name for sample')
    parser_guppy.add_argument(
        '-f', '--fast5', required=True,
        help='fast5 with called bases')
    parser_guppy.add_argument(
        '-p', '--procs', default=1,
        help='multiprocessing')
    parser_guppy.add_argument(
        '-m', '--motif', required=True,
        help='motif e.g. G[A]TC or [C]G')
    parser_guppy.add_argument(
        '-n', '--modname', required=True,
        help='mod name in guppy fast5 modified base alphabet (5mC, 6mA, etc)')
    parser_guppy.add_argument(
        '-b', '--bam', required=True,
        help='.bam file containing alignments of reads from fast5')
    parser_guppy.add_argument(
        '-r', '--ref', required=True,
        help='reference genome fasta (samtools faidx indexed)')

    # threshold and behaviour switches
    parser_guppy.add_argument(
        '--minprob', default=0.8,
        help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_guppy.add_argument(
        '-a', '--append', action='store_true', default=False,
        help='append to database')
    parser_guppy.add_argument(
        '--include_unmatched', default=False, action='store_true',
        help='include sites where read base does not match genome base')
    parser_guppy.add_argument(
        '--motifsize', default=2,
        help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    # --force is registered without any help text.
    parser_guppy.add_argument(
        '--force', action='store_true', default=False)

    # options for nanopolish db
    parser_nanopolish.add_argument(
        '-m', '--methdata', required=True,
        help='whole genome nanopolish methylation output, can be comma-delimited')
    parser_nanopolish.add_argument(
        '-d', '--db', default=None,
        help='database name (default: auto-infer)')
    parser_nanopolish.add_argument(
        '-t', '--thresh', default=2.5,
        help='llr threshold (default = 2.5; if using --scalegroup the suggested setting is 2.0)')
    parser_nanopolish.add_argument(
        '-a', '--append', action='store_true', default=False,
        help='append to database')
    parser_nanopolish.add_argument(
        '-s', '--scalegroup', action='store_true', default=False,
        help='scale threshold by number of CpGs in a group')
    # modification label and motif both have CpG-oriented defaults
    parser_nanopolish.add_argument(
        '-n', '--modname', default='CpG',
        help='modification type (tag if combining multiple mods, default = "CpG")')
    parser_nanopolish.add_argument(
        '--motif', default='CG',
        help='mod motif (default = CG)')

    # options for adjusting methylation / unmethylation cutoffs in methylartist db
    # (all four options are mandatory)
    parser_adjustcutoffs.add_argument(
        '-d', '--db', required=True,
        help='methylartist database')
    parser_adjustcutoffs.add_argument(
        '--mod', required=True,
        help='modification to plot (will list for user if incorrect)')
    parser_adjustcutoffs.add_argument(
        '-m', '--methylated', required=True,
        help='mark as methylated above cutoff value')
    parser_adjustcutoffs.add_argument(
        '-u', '--unmethylated', required=True,
        help='mark as unmethylated below cutoff value')

    # options for score distribution exploration function

    # data sources and sampling
    parser_scoredist.add_argument(
        '-d', '--db', default=None,
        help='methylartist database(s), can be comma-delimited')
    parser_scoredist.add_argument(
        '-b', '--bam', default=None,
        help='one or more .bam files with MM and ML tags for modification calls (see samtags spec)')
    parser_scoredist.add_argument(
        '-n', '--n', default=1000000,
        help='sample size (default = 1000000)')
    parser_scoredist.add_argument(
        '-m', '--mod', required=True,
        help='modification to plot (will list for user if incorrect)')
    parser_scoredist.add_argument(
        '-r', '--ref', default=None,
        help='reference genome fasta (samtools faidx indexed)')
    parser_scoredist.add_argument(
        '-o', '--outfile', default=None,
        help='output file name (default: generated from input)')
    parser_scoredist.add_argument(
        '--motif', default=None,
        help='modified motif to highlight (e.g. CG)')

    # plot appearance; --xmin, --xmax and --svg carry no help text
    parser_scoredist.add_argument(
        '--xmin', default=None)
    parser_scoredist.add_argument(
        '--xmax', default=None)
    parser_scoredist.add_argument(
        '--lw', default=2,
        help='line width (default = 2)')
    parser_scoredist.add_argument(
        '--palette', default="tab10",
        help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_scoredist.add_argument(
        '--svg', default=False, action='store_true')


    # Parse the command line and pass the resulting namespace to main().
    args = parser.parse_args()
    main(args)
