#!/home/taewing/miniconda3/envs/cyvcf2/bin/python

import os
import sys
import argparse
import gzip
import csv
import logging
import multiprocessing as mp
import matplotlib
import random
import itertools
import sqlite3
from matplotlib.colors import NoNorm

# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')

# Illustrator compatibility
new_rc_params = {'text.usetex': False, "svg.fonttype": 'none'}
matplotlib.rcParams.update(new_rc_params)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch
import seaborn as sns

import numpy as np
import pandas as pd
import pysam
import scipy.stats as ss

from pandas.api.types import CategoricalDtype

skbio_installed = True

try:
    import skbio.alignment as skalign
    import skbio.sequence as skseq

except ModuleNotFoundError:
    skbio_installed = False

from uuid import uuid4
from collections import defaultdict as dd
from collections import Counter
from itertools import product
from operator import itemgetter
from copy import deepcopy

ont_fast5_api_installed = True

try:
    from ont_fast5_api.fast5_interface import get_fast5_file

except ModuleNotFoundError:
    ont_fast5_api_installed = False

from bx.intervals.intersection import Intersecter, Interval

# Module-wide logging: each message is prefixed with a timestamp and the
# module logger emits records at INFO level and above.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Read:
    ''' per-read store of modified base calls, keyed by site location '''

    def __init__(self, read_name, cpg_loc, stat, methcall, modname, phase=None):
        self.read_name = read_name
        self.phase = phase

        # per-site lookups, all keyed on the site location
        self.llrs = {}
        self.meth_calls = {}
        self.mod_names = {}

        # number of recorded calls whose state is not 0
        self.call_count = 0

        # used for locus per-read plot
        self.ypos = None
        self.starts = []
        self.ends = []

        self.add_mod(cpg_loc, stat, methcall, modname)

    def add_mod(self, cpg_loc, stat, methcall, modname):
        ''' record one call (state must be -1, 0 or 1) at cpg_loc '''
        assert methcall in (-1,0,1)

        self.llrs[cpg_loc] = stat
        self.meth_calls[cpg_loc] = methcall
        self.mod_names[cpg_loc] = modname

        self.call_count += int(methcall != 0)

    def overlap(self, other):
        ''' True if the outer extents (starts/ends) of the two reads share more than zero bases '''
        right = min(max(self.ends), max(other.ends))
        left = max(min(self.starts), min(other.starts))
        return right - left > 0

    def mod_count(self):
        ''' number of sites called as modified (state 1) '''
        return sum(1 for state in self.meth_calls.values() if state == 1)

    def site_count(self):
        ''' number of sites with any recorded call '''
        return len(self.meth_calls)

    def mod_frac(self):
        ''' fraction of recorded sites called as modified; 0 when there are no sites '''
        total = self.site_count()

        if total == 0:
            return 0

        return self.mod_count()/total


class Gene:
    ''' gene model: transcript and CDS extents plus a start-sorted list of exon blocks '''

    def __init__(self, ensg, name):
        self.ensg = ensg
        self.name = name
        self.tx_start = None
        self.tx_end = None
        self.cds_start = None
        self.cds_end = None
        self.exons = []

    def add_exon(self, block):
        ''' add a two-element exon block; stored as [start, end] with start <= end, list kept sorted by start '''
        assert len(block) == 2

        start, end = block

        # only reversed input needs swapping (the test used to be inverted,
        # which flipped well-formed blocks so that start > end)
        if start > end:
            start, end = end, start

        # store a fresh list so the caller's block is not modified
        self.exons.append([start, end])
        self.exons.sort(key=itemgetter(0))

    def add_tx(self, block):
        ''' widen the transcript extent to include block (start must be < end) '''
        assert len(block) == 2
        assert block[0] < block[1]
        if self.tx_start is None or self.tx_start > block[0]:
            self.tx_start = block[0]

        if self.tx_end is None or self.tx_end < block[1]:
            self.tx_end = block[1]

    def add_cds(self, block):
        ''' widen the CDS extent to include block; blocks with start > end are logged and ignored '''
        assert len(block) == 2
        if block[0] > block[1]:
            logger.warning('CDS block start > end in gene %s' % self.ensg)
            return None

        if self.cds_start is None or self.cds_start > block[0]:
            self.cds_start = block[0]

        if self.cds_end is None or self.cds_end < block[1]:
            self.cds_end = block[1]

    def has_tx(self):
        ''' True once both transcript bounds are set '''
        return None not in (self.tx_start, self.tx_end)

    def has_cds(self):
        ''' True once both CDS bounds are set '''
        return None not in (self.cds_start, self.cds_end)

    def merge_exons(self):
        ''' collapse overlapping exon blocks (blocks that merely touch are kept separate) '''
        if len(self.exons) == 0:
            return

        new_exons = []
        last_block = self.exons[0]

        for block in self.exons[1:]:
            if min(block[1], last_block[1]) - max(block[0], last_block[0]) > 0: # overlap
                # grow the pending block and keep it pending
                last_block = [min(block[0], last_block[0]), max(block[1], last_block[1])]

            else:
                # no overlap: the pending block is final, start a new one
                new_exons.append(last_block)
                last_block = block

        new_exons.append(last_block)

        self.exons = new_exons


# complement lookup for uppercase A/C/G/T plus N (no lowercase or IUPAC ambiguity codes)
RC = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}

class Motif:
    ''' sequence motif with the key base in square brackets: left bases, [key base], right bases '''

    def __init__(self, motif):

        self.motif = motif

        assert len(motif.split('[')) == 2, 'bad motif syntax: %s' % motif
        assert len(motif.split(']')) == 2, 'bad motif syntax: %s' % motif

        # bases before '[', after ']' and between the brackets
        self.left_bases  = motif.split('[')[0]
        self.right_bases = motif.split(']')[1]
        self.key_base    = motif.split('[')[1].split(']')[0]

        # motif with the brackets removed; this is what gets matched
        self.full_seq    = self.left_bases + self.key_base + self.right_bases

        assert self.key_base in 'ATCG', 'bad motif syntax: %s' % motif


    def match_seq(self, seq):
        ''' returns 0-based coords of key base for matches in seq '''

        sites = dd(list)

        seq = seq.upper()
        width = len(self.full_seq)

        # +1 so that a match ending on the last base of seq is also found
        for i in range(len(seq) - width + 1):
            if seq[i:i+width] == self.full_seq:
                sites[i + len(self.left_bases)].append(self.motif)

        return sites
        

class Mod:
    ''' one motif-matched base on a read, with its modified base table row and (once aligned) genome placement '''

    def __init__(self, read_id, read_pos, read_base, sites, mods):
        # read-level information known at construction
        self.read_id = read_id
        self.read_pos = read_pos
        self.read_base = read_base
        self.sites = sites
        self.mods = mods

        # alignment-level information, filled in later
        self.chrom = None
        self.genome_pos = None
        self.genome_base = None
        self.aln_strand = None

    def __str__(self):
        ''' tab-delimited record: read, genome placement, read position/base, motifs, then one column per mods entry '''
        fields = [
            self.read_id,
            self.chrom,
            str(self.genome_pos),
            self.genome_base,
            self.aln_strand,
            str(self.read_pos),
            self.read_base,
            ','.join(self.sites),
            '\t'.join(str(value) for value in self.mods),
        ]

        return '\t'.join(fields)


def base_field(mod_metadata):
    ''' returns one column name per character of the metadata alphabet

    A/C/G/T are kept as-is; every other character is replaced, in order, by
    the next whitespace-separated entry of 'modified_base_long_names'.
    The alphabet is read from 'output_alphabet' if present, otherwise from
    'modified_base_alphabet'; KeyError is raised if neither field exists.
    '''
    long_names = mod_metadata['modified_base_long_names'].split()

    alpha_field_name = None

    for candidate in ('output_alphabet', 'modified_base_alphabet'):
        if candidate in mod_metadata:
            alpha_field_name = candidate
            break

    if alpha_field_name is None:
        # previously surfaced as an uninformative "KeyError: None"
        raise KeyError('no alphabet field (output_alphabet or modified_base_alphabet) in modified base metadata')

    bases = []
    i = 0 # index of the next unused long name

    for base in mod_metadata[alpha_field_name]:
        if base not in 'ACTG':
            bases.append(long_names[i])
            i += 1
        else:
            bases.append(base)

    return bases


def pandas_df(out, mod_base, can_base):
    ''' builds a DataFrame from tab-delimited text lines and adds log-probability columns

    out:      iterable of lines; the first is the header, the rest are rows
    mod_base: header name of the modified base column
    can_base: header name of the canonical base column

    Both base columns are converted with log((value+1)/256); 'modstat' is the
    modified minus canonical log value. Columns follow header order and rows
    are indexed from 1. Header-only input yields an empty frame with all
    columns present.
    '''
    header = []
    rows = []

    for i, line in enumerate(out):
        if i == 0:
            header = line.strip().split()
            continue

        rows.append(dict(zip(header, line.strip().split('\t'))))

    data = pd.DataFrame(rows, columns=header)
    data.index = pd.RangeIndex(1, len(rows) + 1)

    # astype(float) keeps the arithmetic numeric even when there are no rows
    can_vals = pd.to_numeric(data[can_base]).astype(float)
    mod_vals = pd.to_numeric(data[mod_base]).astype(float)

    data['can_log_prob'] = np.log((can_vals+1)/256)
    data['mod_log_prob'] = np.log((mod_vals+1)/256)
    data['modstat'] = data['mod_log_prob'] - data['can_log_prob']

    return data


def guppy_f5_fetch(fast5, bam_db, args):
    ''' extracts motif-matched modified base rows from one fast5 and writes them, with genome coordinates, to a .gupmod.tsv

    fast5:  file name relative to the args.fast5 directory
    bam_db: sqlite database queried as "SELECT sam FROM bam WHERE readname=..."
    args:   uses args.bam, args.ref, args.fast5, args.motif, args.modname and args.include_unmatched

    Returns the output file name, or None if the fast5 has no reads, lacks
    basecalling metadata, or does not list args.modname among its mod names.
    '''
    bases = []
    modbases = dd(list)
    header = None
    out = []

    with pysam.AlignmentFile(args.bam) as bam, pysam.Fastafile(args.ref) as ref:
        m = Motif(args.motif)

        with get_fast5_file(args.fast5 + '/' + fast5, mode="r") as f5:
            for read_id in f5.get_read_ids():
                read = f5.get_read(read_id)
                latest_basecall = read.get_latest_analysis('Basecall_1D')

                mod_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/ModBaseProbs')
                called_base_table = read.get_analysis_dataset(latest_basecall, 'BaseCalled_template/Fastq')

                mod_table_path = '{}/BaseCalled_template/ModBaseProbs'.format(latest_basecall)
                called_table_path = '{}/BaseCalled_template/Fastq'.format(latest_basecall)

                mod_metadata = read.get_analysis_attributes(mod_table_path)
                called_metadata = read.get_analysis_attributes(called_table_path)

                if None in (mod_metadata, called_metadata):
                    logger.error('fast5 %s does not seem to contain basecalling information' % fast5)
                    return None

                # column layout is taken from the first read only
                if len(bases) == 0:
                    bases = base_field(mod_metadata)

                    mod_names = [b for b in bases if b not in ('ATCG')]

                    if args.modname not in mod_names:
                        logger.error('mod "%s" not found in metadata. Available mod names: %s' % (args.modname, ','.join(mod_names)))
                        return None

                    header = "readname\tchrom\tpos\tgenome_base\tstrand\tread_pos\tread_base\tmotif\t" + '\t'.join(bases) + "\n"

                # second line of the fastq record is the basecalled sequence
                seq = called_base_table.split('\n')[1]

                assert len(seq) == len(mod_base_table)

                sites = m.match_seq(seq)

                for i, mods in enumerate(mod_base_table):
                    if i in sites:
                        modbases[read_id].append(Mod(read_id, i, seq[i], sites[i], mods))

        if header is None:
            # no reads in the file: nothing to write (used to fail on an unbound header)
            logger.error('fast5 %s does not contain any reads' % fast5)
            return None

        logger.info('processed %d reads from %s' % (len(modbases), fast5))

        logger.info('parsing alignments from %s' % bam_db)

        out.append(header)

        conn = sqlite3.connect(bam_db)

        try:
            c = conn.cursor()

            for readname in modbases:
                # bound parameter rather than string formatting into the SQL
                for row in c.execute("SELECT sam FROM bam WHERE readname=?", (readname,)):
                    read = pysam.AlignedSegment.fromstring(row[0], bam.header)

                    # query position -> reference position, aligned pairs only
                    pos_lookup = {}
                    for ap in read.get_aligned_pairs():
                        if None not in ap:
                            pos_lookup[ap[0]] = ap[1]

                    for modbase in modbases[read.query_name]:
                        modbase.aln_strand = '+'
                        aligned_read_pos = modbase.read_pos

                        if read.is_reverse:
                            modbase.aln_strand = '-'
                            aligned_read_pos = (len(read.seq) - modbase.read_pos)-1 # SAM format stores read on + strand

                        if aligned_read_pos in pos_lookup:
                            modbase.chrom = read.reference_name
                            modbase.genome_pos = pos_lookup[aligned_read_pos]
                            modbase.genome_base = ref.fetch(modbase.chrom, modbase.genome_pos, modbase.genome_pos+1)

                            modbase.genome_base = modbase.genome_base.upper()

                            if args.include_unmatched:
                                out.append(str(modbase) + '\n')

                            else:
                                if modbase.aln_strand == '+' and modbase.genome_base == modbase.read_base:
                                    out.append(str(modbase) + '\n')

                                # .get: reference bases outside RC (ambiguity codes) simply do not match
                                if modbase.aln_strand == '-' and RC.get(modbase.genome_base) == modbase.read_base:
                                    out.append(str(modbase) + '\n')

        finally:
            conn.close()

    data = pandas_df(out, args.modname, m.key_base)
    out_fn = args.fast5 + '/' + '.'.join(fast5.split('.')[:-1]) + '.gupmod.tsv'
    logger.info('writing output from %s to %s' % (fast5, out_fn))
    data.to_csv(out_fn, sep='\t', index=False)

    return out_fn


def exclude_ambiguous_reads(fn, chrom, start, end, min_mapq=10, spanning_only=False):
    ''' returns names of reads in the region whose aligned positions extend outside [start, end]

    Reads must have mapq >= min_mapq; with spanning_only they must also start
    before start and end after end. Reads with no aligned reference positions
    are skipped. Names may repeat if a read has several alignments.
    '''
    reads = []

    # context manager so the bam handle is released
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            if not p:
                continue

            if not (p[0] < start or p[-1] > end):
                continue

            if read.mapq < min_mapq:
                continue

            if spanning_only and not (read.reference_start < start and read.reference_end > end):
                continue

            reads.append(read.query_name)

    return reads


def get_ambiguous_reads(fn, chrom, start, end, min_mapq=10, w=50):
    ''' returns names of reads in the region that are low-mapq or not anchored outside it

    A read is reported if mapq < min_mapq, or if its first aligned position is
    > start-w and its last aligned position is < end+w. Reads with no aligned
    reference positions are reported as well, since they cannot be anchored.
    '''
    reads = []

    # context manager so the bam handle is released
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            p = read.get_reference_positions()

            if read.mapq < min_mapq or not p or (p[0] > start-w and p[-1] < end+w):
                reads.append(read.query_name)

    return reads


def get_reads(fn, chrom, start, end, min_mapq=10, spanning_only=False):
    ''' returns names of reads in the region with mapq >= min_mapq

    With spanning_only, a read must start before start and end after end.
    Names may repeat if a read has several alignments in the region.
    '''
    reads = []

    # context manager so the bam handle is released
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            if spanning_only and not (read.reference_start < start and read.reference_end > end):
                continue

            reads.append(read.query_name)

    return reads


def get_phased_reads(fn, chrom, start, end, min_mapq=10, tag_untagged=False, ignore_tags=False, HP_only=False):
    ''' returns a dict of read name -> phase label for reads in the region with mapq >= min_mapq

    The label is 'PS:HP' (or just HP with HP_only) when both HP and PS tags are
    present and ignore_tags is off. Otherwise it is 'unphased' if tag_untagged
    or ignore_tags is set, else None.
    '''
    reads = {}

    # context manager so the bam handle is released
    with pysam.AlignmentFile(fn) as bam:
        for read in bam.fetch(chrom, start, end):
            if read.mapq < min_mapq:
                continue

            phase = None

            if tag_untagged or ignore_tags:
                phase = 'unphased'

            HP = None
            PS = None

            if not ignore_tags:
                for tag in read.get_tags():
                    if tag[0] == 'HP':
                        HP = tag[1]
                    if tag[0] == 'PS':
                        PS = tag[1]

            # a phase label needs both tags, even when only HP is reported
            if None not in (HP, PS):
                phase = str(HP) if HP_only else str(PS) + ':' + str(HP)

            reads[read.query_name] = phase

    return reads


def single_seq_fa(fn):
    ''' reads a fasta file and returns its sequence lines joined into one string

    Asserts if a header line is seen after sequence has already been read.
    '''
    pieces = []

    with open(fn, 'r') as handle:
        for raw in handle:
            if not raw.startswith('>'):
                pieces.append(raw.strip())
                continue

            # a header is only acceptable before any sequence was collected
            assert not any(pieces), 'input fa must have only one entry'

    return ''.join(pieces)


def rc(dna):
    ''' reverse complement '''
    # bases and IUPAC codes listed here are complemented (case preserved);
    # any other character is left unchanged
    forward = 'acgtrymkbdhvACGTRYMKBDHV'
    reverse = 'tgcayrkmvhdbTGCAYRKMVHDB'
    table = str.maketrans(forward, reverse)

    return dna[::-1].translate(table)


def get_modnames(meth_db):
    ''' returns the distinct values of the mod column in the modnames table of meth_db (exits if the file is missing) '''
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)

    try:
        c = conn.cursor()
        return [row[0] for row in c.execute("SELECT DISTINCT mod FROM modnames")]

    finally:
        # the connection is only needed for this one query
        conn.close()


def get_cutoffs(meth_db, mod):
    ''' returns the first (upper, lower) row of the cutoffs table for modname == mod, or None if there is none (exits if the file is missing) '''
    if not os.path.exists(meth_db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    conn = sqlite3.connect(meth_db)

    try:
        c = conn.cursor()

        # bound parameter: mod is matched literally, never interpreted as SQL
        return c.execute("SELECT upper,lower FROM cutoffs WHERE modname=?", (mod,)).fetchone()

    finally:
        conn.close()


def densecall_filter(reads, max_density=.7):
    ''' reads is a dict of Read objects with modified base calls; keeps those whose mod_frac() does not exceed max_density '''
    kept = {name: read for name, read in reads.items() if read.mod_frac() <= max_density}

    filtered_count = len(reads) - len(kept)

    logger.info('filtered %d reads via --max_read_density %f' % (filtered_count, max_density))

    return kept


def mods_methbam(bam_fn):
    ''' returns the list of modification codes in the Mm tag of the first read that carries a usable one

    Codes are stripped of a trailing '?' or '.', matching what parse_methbam
    yields. Returns None if no read has a usable Mm tag.
    '''
    mm_warned = False

    # context manager so the bam handle is released
    with pysam.AlignmentFile(bam_fn) as bam:
        for rec in bam.fetch():
            try:
                mm = str(rec.get_tag('Mm')).rstrip(';')
            except KeyError:
                if not mm_warned:
                    logger.debug('cannot find Mm tag in at least one read, ensure this bam has Mm and Ml tags!')
                    mm_warned = True

                # every untagged read is skipped, not only the first one
                continue

            mods = []

            for mod_string in mm.split(';'):
                mod_info = mod_string.split(',')[0]

                mod_strand = '+'
                if '-' in mod_info:
                    mod_strand = '-'

                try:
                    mod_base, mod_type = mod_info.split(mod_strand)
                except ValueError:
                    logger.debug('%s: malformed mod string for read %s (%s) skipped.' % (bam_fn, rec.query_name, mod_info))
                    continue

                mods.append(mod_type.rstrip('?.'))

            if mods:
                return mods # assumes all mods are represented for each read that has an Mm tag

    return None
        

def split_ml(mod_strings, ml):
    ''' cuts ml into consecutive slices, one per mod string

    Each slice is as long as the number of comma-separated items following the
    first item of the corresponding mod string. Asserts that the lengths add
    up to len(ml).
    '''
    # first item of each mod string describes the mod base and is not counted
    counts = [len(ms.split(',')) - 1 for ms in mod_strings]

    assert sum(counts) == len(ml), 'mod bam formatting error'

    mls = []
    offset = 0

    for n_calls in counts:
        chunk = ml[offset:offset+n_calls]

        assert n_calls == len(chunk), 'mod bam formatting error'

        mls.append(chunk)
        offset += n_calls

    return mls


def parse_methbam(bam_fn, reads, chrom, start, end, motifsize=2, meth_thresh=0.8, can_thresh=0.8):
    ''' generator over modified base calls stored in Mm/Ml tags of reads in a region

    Only mapped reads whose name is in reads are used. Yields tuples of
    (read name, reference name, genome position, p_mod, methstate, mod type)
    where p_mod is the Ml score / 255 and methstate is 1 if p_mod > meth_thresh,
    -1 if 1 - p_mod > can_thresh, else 0. For reverse-strand reads the genome
    position is shifted left by motifsize - 1. Calls on bases that are not
    aligned to the reference are skipped.
    '''
    mm_warned = False

    # context manager so the bam handle is released when the generator finishes
    with pysam.AlignmentFile(bam_fn) as bam:
        for rec in bam.fetch(chrom, start, end):
            if rec.is_unmapped:
                continue

            if rec.qname not in reads:
                continue

            try:
                mm = str(rec.get_tag('Mm')).rstrip(';')
            except KeyError:
                if not mm_warned:
                    logger.debug('cannot find Mm tag in at least one read, ensure this bam has Mm and Ml tags!')
                    mm_warned = True

                # skip every untagged read; falling through would reuse the previous read's tag
                continue

            try:
                ml = rec.get_tag('Ml')
            except KeyError:
                continue

            seq = rec.seq

            if seq is None:
                # record stores no sequence, so calls cannot be placed
                continue

            # query position -> reference position, aligned pairs only
            ap = dict([(k, v) for (k, v) in rec.get_aligned_pairs() if None not in (k,v)])

            mod_strings = mm.split(';')
            mls = split_ml(mod_strings, ml)

            read_len = len(seq)

            if rec.is_reverse:
                seq = rc(seq)

            for mod_string, scores in zip(mod_strings, mls):
                m = mod_string.split(',')
                mod_info = m[0]

                mod_relpos = list(map(int, m[1:]))

                mod_strand = '+'

                if '-' in mod_info:
                    mod_strand = '-'

                try:
                    mod_base, mod_type = mod_info.split(mod_strand)
                    mod_type = mod_type.rstrip('?.')
                except ValueError:
                    logger.debug('%s: malformed mod string for read %s (%s) skipped.' % (bam_fn, rec.qname, mod_info))
                    continue

                assert len(mod_type) == 1, 'multiple modfications listed this way: %s is not yet supported, please send me an example!' % mod_info

                # positions of the mod base in the (strand-corrected) read sequence
                base_pos = [i for i, b in enumerate(seq) if b == mod_base]

                i = -1

                for skip, score in zip(mod_relpos, scores):
                    # each entry skips 'skip' occurrences of the base, then names the next one
                    i += 1

                    if skip > 0:
                        i += skip

                    query_pos = base_pos[i]

                    if rec.is_reverse:
                        # convert back to the coordinate system of the stored (+ strand) sequence
                        query_pos = read_len-base_pos[i]-1

                    if query_pos not in ap:
                        # base lies in an insertion or clipped segment
                        continue

                    genome_pos = ap[query_pos]

                    if rec.is_reverse:
                        genome_pos = genome_pos - (int(motifsize)-1)

                    p_mod = score/255
                    p_can = 1-p_mod

                    assert p_mod <= 1.0

                    methstate = 0

                    if p_mod > meth_thresh:
                        methstate = 1

                    if p_can > can_thresh:
                        methstate = -1

                    yield (rec.qname, rec.reference_name, genome_pos, p_mod, methstate, mod_type)


def get_segmeth_calls(args, bam_fn, mod_names, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam):
    ''' tallies call states per modification over one segment

    Reads are selected from bam_fn (optionally excluding unanchored reads via
    args.excl_ambig, and restricted to phase if given); calls come from the
    bam's Mm/Ml tags when methbam is set, otherwise from the methdata table of
    each database in meth_dbs.

    Returns (seg_result, info): seg_result maps each name in mod_names to a
    dict of call state -> count, and info is
    (chrom, seg_start, seg_end, seg_name, seg_strand, read_count).
    '''
    if 'spanning_only' not in vars(args):
        args.spanning_only=False

    reads = []
    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only)
    else:
        reads = get_reads(bam_fn, chrom, seg_start, seg_end, min_mapq=int(args.min_mapq), spanning_only=args.spanning_only)

    reads = set(reads)

    if phase:
        phased_reads_dict = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
        reads = set(r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase)

    seg_reads = {}

    if methbam:
        for row in parse_methbam(bam_fn, reads, chrom, seg_start, seg_end, meth_thresh=0.8, can_thresh=0.8):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if chrom != cg_chrom:
                continue

            if cg_start < seg_start or cg_start > seg_end:
                continue

            # positions are stored relative to the segment start
            cg_seg_start = cg_start - seg_start

            if index not in seg_reads:
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    else:
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            cursors = [conn.cursor() for conn in conns]

            for index in reads:
                for c in cursors:
                    # read name is passed as a bound parameter, not formatted into the SQL
                    for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                        cg_chrom, cg_start, stat, methstate, modname = row

                        if chrom != cg_chrom:
                            continue

                        if cg_start < seg_start or cg_start > seg_end:
                            continue

                        cg_seg_start = cg_start - seg_start

                        if index not in seg_reads:
                            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                        else:
                            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    seg_result = {}

    for modname in mod_names:
        seg_meth_calls = dd(int)

        for name, read in seg_reads.items():
            for loc, call in read.meth_calls.items():
                if read.mod_names[loc] == modname:
                    seg_meth_calls[call] += 1

        seg_result[modname] = seg_meth_calls

    read_count = len(seg_reads)

    return seg_result, (chrom, seg_start, seg_end, seg_name, seg_strand, read_count)


def get_meth_locus(args, bam, meth_dbs, mod, phase=None, methbam=False, HP_only=False):
    ''' used for locus plots: collects per-read calls of one modification over args.interval

    args.interval is 'chrom:start-end' (commas in the coordinates are
    ignored). Calls come from the bam's Mm/Ml tags when methbam is set,
    otherwise from the methdata table of each database in meth_dbs. Calls with
    methstate 0, another modification name, or a position outside the interval
    are dropped.

    Returns a dict of read name -> Read, with positions relative to the
    interval start.
    '''
    if phase:
        logger.info('fetching reads from %s with mod %s on phase %s' % (bam, mod, phase))
    else:
        logger.info('fetching reads from %s with mod %s' % (bam, mod))

    assert ':' in args.interval
    assert '-' in args.interval

    # split on the last ':' so contig names that themselves contain ':' still parse
    chrom, pos = args.interval.rsplit(':', 1)
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))

    if not methbam:
        for meth_db in meth_dbs:
            if not os.path.exists(meth_db):
                sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_db)

    # get list of relevant reads (with excl_ambig, excludes reads not anchored outside interval)
    reads = []
    if hasattr(args, 'excl_ambig') and args.excl_ambig:
        reads = exclude_ambiguous_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq))
    else:
        reads = get_reads(bam, chrom, elt_start, elt_end, min_mapq=int(args.min_mapq))

    # a set gives de-duplication and constant-time membership tests downstream
    reads = set(reads)

    if phase:
        phased_reads_dict = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=HP_only)
        reads = set(r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase)

    if hasattr(args, 'unambig_highlight') and hasattr(args, 'highlight'):
        if args.unambig_highlight and args.highlight:
            h_coords = []
            for h in args.highlight.split(','):
                if ':' in h:
                    h = h.split(':')[-1]

                h_coords += map(int, h.split('-'))

            h_coords.sort()

            # outer extent of all highlight regions
            h_start, h_end = h_coords[0], h_coords[-1]

            excl_reads = set(get_ambiguous_reads(bam, chrom, h_start, h_end, min_mapq=int(args.min_mapq)))

            reads = reads - excl_reads

    seg_reads = {}

    if methbam:
        for row in parse_methbam(bam, reads, chrom, elt_start, elt_end, motifsize=args.motifsize, meth_thresh=0.8, can_thresh=0.8):
            index, cg_chrom, cg_start, stat, methstate, modname = row

            if methstate == 0:
                continue

            if modname != mod:
                continue

            if chrom != cg_chrom:
                continue

            if cg_start < elt_start or cg_start > elt_end:
                continue

            cg_seg_start = cg_start - elt_start

            if index not in seg_reads:
                seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
            else:
                seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    else:
        conns = [sqlite3.connect(meth_db) for meth_db in meth_dbs]

        try:
            cursors = [conn.cursor() for conn in conns]

            for index in reads:
                for c in cursors:
                    # read name is passed as a bound parameter, not formatted into the SQL
                    for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):
                        cg_chrom, cg_start, stat, methstate, modname = row

                        if methstate == 0:
                            continue

                        if modname != mod:
                            continue

                        if chrom != cg_chrom:
                            continue

                        if cg_start < elt_start or cg_start > elt_end:
                            continue

                        cg_seg_start = cg_start - elt_start

                        if index not in seg_reads:
                            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                        else:
                            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

        finally:
            for conn in conns:
                conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    return seg_reads


def get_meth_profile_composite(args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase):
    ''' builds a smoothed methylation profile for one segment per bam, in reference element coordinates

    For each bam in data: collect calls of use_mod over the segment (from Mm/Ml
    tags when methbam is set, otherwise from the database data[bam]), window
    and smooth them, locally align the segment sequence to the single-entry
    fasta args.teref, and translate call positions through that alignment.

    Returns a dict. Successful samples are keyed by the bam name without its
    extension (plus '.phase<phase>' when phased) and map to
    (positions, profile). Samples that fail a filter (too few calls, too few
    windows, no or too short alignment) are keyed by the bam name as given and
    map to None.
    '''
    per_bam_results = {}

    if not data:
        return per_bam_results

    # the reference element does not depend on the bam, so read it once
    te_ref_seq = single_seq_fa(args.teref).upper()

    for bam in data:
        logger.info('profiling %s: %s:%d-%d:%s:%s phase: %s' % (bam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, str(phase)))

        reads = []
        if args.excl_ambig:
            reads = exclude_ambiguous_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq))
        else:
            reads = get_reads(bam, seg_chrom, seg_start, seg_end, min_mapq=int(args.min_mapq))

        if phase:
            phased_reads_dict = get_phased_reads(bam, seg_chrom, seg_start, seg_end, tag_untagged=(phase=='unphased'), min_mapq=int(args.min_mapq), HP_only=True)
            reads = [r for r in reads if r in phased_reads_dict and phased_reads_dict[r] == phase]

        reads = set(reads)

        seg_reads = {}

        if methbam:
            for row in parse_methbam(bam, reads, seg_chrom, seg_start, seg_end, motifsize=len(args.motif), meth_thresh=0.8, can_thresh=0.8):
                index, cg_chrom, cg_start, stat, methstate, modname = row

                if seg_chrom != cg_chrom:
                    continue

                if cg_start < seg_start or cg_start > seg_end:
                    continue

                # same modification filter as the database branch
                if modname != use_mod:
                    continue

                cg_seg_start = cg_start - seg_start

                if index not in seg_reads:
                    seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                else:
                    seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

        else:
            conn = sqlite3.connect(data[bam])

            try:
                c = conn.cursor()

                for index in reads:
                    # read name is passed as a bound parameter, not formatted into the SQL
                    for row in c.execute("SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos", (index,)):

                        cg_chrom, cg_start, stat, methstate, modname = row

                        if seg_chrom != cg_chrom:
                            continue

                        if cg_start < seg_start or cg_start > seg_end:
                            continue

                        if modname != use_mod:
                            continue

                        cg_seg_start = cg_start - seg_start

                        if index not in seg_reads:
                            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname)
                        else:
                            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

            finally:
                conn.close()

        if args.max_read_density is not None:
            seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

        # one row per (read, site), keyed by a throwaway unique id
        meth_table = dd(dict)
        sample = '.'.join(bam.split('.')[:-1])
        call_count = 0

        for name, read in seg_reads.items():
            for loc in read.llrs.keys():
                uuid = str(uuid4())
                meth_table[uuid]['loc'] = loc
                meth_table[uuid]['modstat'] = read.llrs[loc]
                meth_table[uuid]['read'] = name
                meth_table[uuid]['sample'] = sample
                meth_table[uuid]['call'] = read.meth_calls[loc]

                if read.meth_calls[loc] in (1, -1):
                    call_count += 1

        if call_count < args.mincalls:
            logger.warning('too few calls on seg: %s:%d-%d (%d)' % (seg_chrom, seg_start, seg_end, call_count))
            per_bam_results[bam] = None
            continue

        meth_table = pd.DataFrame.from_dict(meth_table).T
        meth_table['loc'] = pd.to_numeric(meth_table['loc'])
        meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

        # replace coordinates by their dense rank (site number); keep the original for mapping back
        meth_table['orig_loc'] = meth_table['loc']
        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        coord_to_cpg = {}
        cpg_to_coord = {}
        for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
            coord_to_cpg[orig_loc] = new_loc
            cpg_to_coord[new_loc]  = orig_loc

        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        if len(windowed_methfrac) <= int(args.smoothwindowsize):
            logger.warning('too few sites after windowing: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
            per_bam_results[bam] = None
            continue

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        coord_meth_pos = []

        cpg_meth_pos = list(windowed_methfrac.keys())

        for cpg in cpg_meth_pos:
            if cpg in cpg_to_coord:
                if seg_strand == '+':
                    coord_meth_pos.append(cpg_to_coord[cpg])
                if seg_strand == '-':
                    coord_meth_pos.append((seg_end-seg_start)-cpg_to_coord[cpg]-2)

        # alignment to ref elt

        # context manager so the fasta handle is released
        with pysam.Fastafile(args.ref) as ref:
            elt_seq = ref.fetch(seg_chrom, seg_start, seg_end)

        if seg_strand == '-':
            elt_seq = rc(elt_seq)

        elt_seq = elt_seq.upper()

        s_ref = skseq.DNA(te_ref_seq)
        s_elt = skseq.DNA(elt_seq)

        aln_res = []

        try:
            aln_res = skalign.local_pairwise_align_ssw(s_ref, s_elt)
        except IndexError: # scikit-bio throws this if no bases align  >:|
            logger.warning('no align on seg: %s:%d-%d' % (seg_chrom, seg_start, seg_end))
            per_bam_results[bam] = None
            continue

        coord_ref, coord_elt = aln_res[2]

        len_ref = coord_ref[1] - coord_ref[0]
        len_elt = coord_elt[1] - coord_elt[0]

        if len_ref / len(te_ref_seq) < float(args.lenfrac):
            logger.warning('ref align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_ref / len(te_ref_seq)))
            per_bam_results[bam] = None
            continue

        if len_elt / len(elt_seq) < float(args.lenfrac):
            logger.warning('elt align too short on seg: %s:%d-%d (%f)' % (seg_chrom, seg_start, seg_end, len_elt / len(elt_seq)))
            per_bam_results[bam] = None
            continue

        tab_msa = aln_res[0]

        # segment position -> reference element position ('na' where the reference has a gap)
        elt_to_ref_coords = {}

        pos_ref = coord_ref[0]
        pos_elt = coord_elt[0]

        for pos in tab_msa.iter_positions():
            pos = list(pos)
            b_ref = pos[0]
            b_elt = pos[1]

            if '-' not in pos:
                elt_to_ref_coords[pos_elt] = pos_ref
                pos_ref += 1
                pos_elt += 1

            if b_elt == '-':
                pos_ref += 1

            if b_ref == '-':
                elt_to_ref_coords[pos_elt] = 'na'
                pos_elt += 1

        revised_coord_meth_pos = []
        meth_profile = []

        for pos, meth in zip(coord_meth_pos, smoothed_methfrac):
            if pos not in elt_to_ref_coords:
                continue

            revised_pos = elt_to_ref_coords[pos]

            if revised_pos != 'na':
                revised_coord_meth_pos.append(revised_pos)
                meth_profile.append(meth)

        bam_noext = '.'.join(bam.split('.')[:-1])
        if phase:
            bam_noext += '.phase' + phase
        per_bam_results[bam_noext] = (revised_coord_meth_pos, meth_profile)

    return per_bam_results


def get_meth_calls_wg(args, bam_fn, meth_fn, chrom, seg_start, seg_end, phased, mod):
    '''
    Tally per-position methylation calls for reads over chrom:seg_start-seg_end.

    Calls come from the modified-base records of bam_fn (via parse_methbam) when
    args.methdb is None, otherwise from the methylartist sqlite database meth_fn.

    args      -- namespace; reads args.methdb, args.min_mapq, args.max_read_density
    bam_fn    -- alignment file used to select (and phase) reads
    meth_fn   -- sqlite database path (only used when args.methdb is not None)
    phased    -- when true, tally reads with phase '1' / '2' separately
    mod       -- restrict to this modification name (None = keep all)

    Returns [table_1, table_2]; each maps a genomic position to
    {'chr': chrom, 'N': number of non-ambiguous calls, 'X': number of methylated calls}.
    Positions without any non-ambiguous call are omitted. When phased is false
    all reads are tallied into table_1 and table_2 stays empty.
    '''
    methbam = args.methdb is None

    if not methbam and not os.path.exists(meth_fn):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % meth_fn)

    reads = get_phased_reads(bam_fn, chrom, seg_start, seg_end, tag_untagged=phased, min_mapq=int(args.min_mapq), HP_only=True)

    seg_reads = {}

    def record_call(index, cg_chrom, cg_start, stat, methstate, modname):
        # keep only calls for the requested mod that fall inside the segment
        if mod is not None and mod != modname:
            return

        if chrom != cg_chrom:
            return

        if cg_start < seg_start or cg_start > seg_end:
            return

        cg_seg_start = cg_start - seg_start

        if index not in seg_reads:
            # the phase must be attached in both input modes, otherwise the
            # phased tally below discards every read
            seg_reads[index] = Read(index, cg_seg_start, stat, methstate, modname, phase=reads.get(index))
        else:
            seg_reads[index].add_mod(cg_seg_start, stat, methstate, modname)

    if methbam:
        for row in parse_methbam(bam_fn, set(reads), chrom, seg_start, seg_end, meth_thresh=0.8, can_thresh=0.8):
            record_call(*row)

    else:
        conn = sqlite3.connect(meth_fn)

        try:
            c = conn.cursor()
            query = "SELECT chrom, pos, stat, methstate, modname FROM methdata WHERE readname=? ORDER BY pos"

            for index in reads:
                for row in c.execute(query, (index,)):
                    record_call(index, *row)

        finally:
            conn.close()

    if args.max_read_density is not None:
        seg_reads = densecall_filter(seg_reads, max_density=float(args.max_read_density))

    meth_data = {1: dd(list), 2: dd(list)}
    meth_table = {1: dd(dict), 2: dd(dict)}

    def tally(phase):
        # collapse the per-position call lists for one phase into N / X counts
        for loc, calls in meth_data[phase].items():
            N = sum(1 for call in calls if call != 0) # call 0 == ambiguous
            X = sum(1 for call in calls if call == 1) # call 1 == methylated

            if N > 0:
                pos = loc+seg_start
                meth_table[phase][pos]['chr'] = chrom
                meth_table[phase][pos]['N'] = N
                meth_table[phase][pos]['X'] = X

    if phased:
        for read in seg_reads.values():
            # covers None, 'unphased' and any other unexpected tag value
            if read.phase not in ['1','2']:
                continue

            phase = int(read.phase)

            for loc in read.llrs:
                meth_data[phase][loc].append(read.meth_calls[loc])

        for phase in (1,2):
            tally(phase)

    else:
        # pass through as phase 1 if not analysing phases
        for read in seg_reads.values():
            for loc in read.llrs:
                meth_data[1][loc].append(read.meth_calls[loc])

        tally(1)

    return [meth_table[1], meth_table[2]]


def slide_window(meth_table, sample, width=20, slide=2):
    '''
    Sliding-window methylation fraction for one sample (locus / composite plots).

    meth_table is indexed through its 'sample', 'loc' and 'call' columns. The
    window spans (start, end) exclusive, starts centred on the smallest 'loc'
    of the whole table and is shifted by `slide` before every count, the first
    one included. Windows without any call of 1 or -1 are left out.

    Returns (meth_frac, meth_n), both keyed by window midpoint:
    fraction of calls == 1, and number of calls that are 1 or -1.
    '''
    all_locs = meth_table['loc']
    first_loc = min(all_locs)
    last_loc = max(all_locs)

    # only rows of the requested sample can ever be counted
    rows = meth_table.loc[meth_table['sample'] == sample]
    locs = rows['loc']
    is_meth = rows['call'] == 1
    is_unmeth = rows['call'] == -1

    lo = int(first_loc - width/2)
    hi = lo + width

    meth_frac = {}
    meth_n = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        inside = (locs > lo) & (locs < hi)
        n_meth = int((inside & is_meth).sum())
        n_unmeth = int((inside & is_unmeth).sum())
        n_calls = n_meth + n_unmeth

        if n_calls > 0:
            centre = int((lo+hi)/2)
            meth_frac[centre] = n_meth/n_calls
            meth_n[centre] = n_calls

    return meth_frac, meth_n


def slide_window_phased(meth_table, phase, width=20, slide=2):
    '''
    Sliding-window methylation fraction for the rows of one phase.

    meth_table is indexed through its 'phase', 'loc' and 'call' columns. The
    window spans (start, end) exclusive, starts centred on the smallest 'loc'
    of the whole table and is shifted by `slide` before every count, the first
    one included. Windows without any call of 1 or -1 are left out.

    Returns (meth_frac, meth_n), both keyed by window midpoint.
    '''
    all_locs = meth_table['loc']
    first_loc = min(all_locs)
    last_loc = max(all_locs)

    # rows from the other phase never contribute
    rows = meth_table.loc[meth_table['phase'] == phase]
    locs = rows['loc']
    is_meth = rows['call'] == 1
    is_unmeth = rows['call'] == -1

    lo = int(first_loc - width/2)
    hi = lo + width

    meth_frac = {}
    meth_n = {}

    while int((lo+hi)/2) < last_loc:
        lo += slide
        hi += slide

        inside = (locs > lo) & (locs < hi)
        n_meth = int((inside & is_meth).sum())
        n_unmeth = int((inside & is_unmeth).sum())
        n_calls = n_meth + n_unmeth

        if n_calls > 0:
            centre = int((lo+hi)/2)
            meth_frac[centre] = n_meth/n_calls
            meth_n[centre] = n_calls

    return meth_frac, meth_n


def smooth(x, window_len=8, window='hanning'):
    # used for locus plots
    '''
    modified from scipy cookbook: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    Smooth a 1-d numpy array by convolving it with a normalised window.

    x          -- 1-d numpy array
    window_len -- window size, must be even
    window     -- one of 'flat' (moving average), 'hanning', 'hamming',
                  'bartlett', 'blackman'

    Returns an array of the same length as x. x itself is returned unchanged
    when it has fewer than window_len points or when window_len < 3.
    '''

    assert window_len % 2 == 0, '--smoothwindowsize must be an even number'
    assert x.ndim == 1

    if x.size < window_len:
        logger.warning('cannot smooth segment: fewer data points than window_len')
        return x

    if window_len < 3:
        return x

    # explicit lookup instead of eval() on a string built from the argument
    window_funcs = {
        'flat': lambda n: np.ones(n, 'd'), # moving average
        'hanning': np.hanning,
        'hamming': np.hamming,
        'bartlett': np.bartlett,
        'blackman': np.blackman,
    }

    assert window in window_funcs

    # pad both ends with mirrored copies of the signal so the 'valid'
    # convolution still covers every original point
    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]

    w = window_funcs[window](window_len)

    y = np.convolve(w/w.sum(), s, mode='valid')

    # trim the padding back off: window_len-1 points in total
    half = int(window_len/2)

    return y[(half-1):-half]


def mask_methfrac(data, cutoff=20):
    # used for locus plots
    '''
    Find runs of consecutive positions whose value is not above the cutoff.

    data   -- sequence of numbers (converted with np.asarray)
    cutoff -- threshold, converted with int(); values > cutoff are kept

    Returns a list of runs, each run being the list of consecutive indices
    i for which data[i] <= cutoff.
    '''
    above = np.asarray(data) > int(cutoff)

    segs = []

    # group neighbouring indices by whether they pass the cutoff and keep
    # the index lists of the groups that do not
    for passes, run in itertools.groupby(enumerate(above), key=lambda pair: pair[1]):
        if not passes:
            segs.append([idx for idx, _ in run])

    return segs


def build_genes(gtf, chrom, start, end):
    # used for locus plots
    '''
    Collect Gene objects for the GTF records returned by gtf.fetch(chrom, start, end).

    gtf -- object with a fetch(chrom, start, end) method yielding tab-delimited
           9-column GTF lines

    Records lacking either a gene_id or a gene_name attribute are ignored.
    'exon', 'CDS' and 'transcript' features are added to their gene as
    [start, end] blocks.

    Returns a dict mapping gene_id to Gene.
    '''
    genes = {}

    for line in gtf.fetch(chrom, start, end):

        # distinct names so the function arguments are not overwritten
        _chrom, _source, feature, f_start, f_end, _score, _strand, _frame, attribs = line.split('\t')

        block = [int(f_start), int(f_end)]

        attr_dict = {}

        for attrib in attribs.strip().split(';'):
            # split on the first run of whitespace only, so quoted values
            # containing spaces stay in one piece
            parts = attrib.strip().split(None, 1)

            # blank fields and bare flags without a value carry nothing usable
            if len(parts) < 2:
                continue

            key, val = parts
            attr_dict[key] = val.strip().strip('"')

        if 'gene_id' not in attr_dict:
            continue

        if 'gene_name' not in attr_dict:
            continue

        ensg = attr_dict['gene_id']
        name = attr_dict['gene_name']

        if ensg not in genes:
            genes[ensg] = Gene(ensg, name)

        if feature == 'exon':
            genes[ensg].add_exon(block)

        if feature == 'CDS':
            genes[ensg].add_cds(block)

        if feature == 'transcript':
            genes[ensg].add_tx(block)

    return genes


def sample_db(db, mod, n=1000000):
    '''
    Draw up to n randomly chosen 'stat' values for modification `mod` from a
    methylartist sqlite database.

    db  -- path to the database; the program exits if it does not exist
    mod -- value matched against the modname column
    n   -- maximum number of values to return

    Returns a flattened numpy array of the sampled values.
    '''
    if not os.path.exists(db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % db)

    logger.info('sample %d values from %s where modname="%s"' % (n, db, mod))

    conn = sqlite3.connect(db)

    try:
        # bound parameters: no quoting problems, and the mod name cannot be
        # mistaken for a column identifier
        rows = conn.execute('SELECT stat FROM methdata WHERE modname=? ORDER BY RANDOM() LIMIT ?', (mod, int(n))).fetchall()

    finally:
        conn.close()

    return np.asarray(rows).flatten()


## tools

def scoredist(args):
    '''
    Plot the distribution of per-call scores for one modification across one
    or more methylartist databases.

    Reads args.db (comma-delimited database paths), args.mod, args.n,
    args.palette, args.lw, args.xmin, args.xmax and args.svg. Each database's
    stored cutoffs are drawn as dashed vertical lines. The figure is written
    to <db basenames joined by '_'>.scoredist.(svg|png).
    '''
    sample = {}

    meth_cutoffs = []
    unmeth_cutoffs = []

    dbs = args.db.split(',')

    avail_mods = []

    for db in dbs:
        avail_mods += get_modnames(db)

    avail_mods = list(set(avail_mods))

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    for db in dbs:
        assert os.path.exists(db), '%s not found' % db
        upper, lower = get_cutoffs(db, args.mod)
        meth_cutoffs.append(upper)
        unmeth_cutoffs.append(lower)

        sample[db] = sample_db(db, args.mod, n=int(args.n))

    meth_cutoffs = list(set(meth_cutoffs))
    unmeth_cutoffs = list(set(unmeth_cutoffs))

    if len(meth_cutoffs) > 1:
        logger.warning('multiple upper (methylation) cutoffs:' + ','.join(map(str, meth_cutoffs)))

    if len(unmeth_cutoffs) > 1:
        logger.warning('multiple lower (unmethylation) cutoffs:' + ','.join(map(str, unmeth_cutoffs)))

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # draw on the axes created above rather than relying on the implicit current axes
    sns.kdeplot(data=sample, palette=args.palette, lw=float(args.lw), ax=ax)

    top = ax.get_ylim()[1]

    ax.vlines(meth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')
    ax.vlines(unmeth_cutoffs, ymin=0, ymax=top, colors='black', linestyles='dashed')

    xmin, xmax = ax.get_xlim()

    if args.xmin:
        xmin = float(args.xmin)

    if args.xmax:
        xmax = float(args.xmax)

    ax.set_xlim(xmin, xmax)

    out_fn = '_'.join(map(os.path.basename, dbs)) + '.scoredist'

    if args.svg:
        out_fn += '.svg'
    else:
        out_fn += '.png'

    logger.info('plot written to %s' % out_fn)
    fig.savefig(out_fn)

    # release the figure once it is on disk
    plt.close(fig)


def adjustcutoffs(args):
    '''
    Re-derive the stored methylation states of one modification from new cutoffs.

    Reads args.db, args.mod, args.methylated and args.unmethylated. For rows
    of methdata with modname == args.mod: methstate is reset to 0, then set
    to 1 where stat > args.methylated and to -1 where stat < args.unmethylated.
    The matching row of the cutoffs table is updated with the new values.
    Exits if the database is missing or the mod is not present.
    '''
    if not os.path.exists(args.db):
        sys.exit('methylartist database (%s) does not exist, check that full path is included if not in current working directory.' % args.db)

    avail_mods = get_modnames(args.db)

    if args.mod not in avail_mods:
        logger.warning('mod %s not found, available mods: %s' % (args.mod, ','.join(avail_mods)))
        sys.exit()

    upper = float(args.methylated)
    lower = float(args.unmethylated)

    conn = sqlite3.connect(args.db)

    try:
        c = conn.cursor()

        # values are bound as parameters: the thresholds keep their full
        # precision and the mod name needs no quoting
        logger.info('%s: reset methylation states to 0 for mod %s' % (args.db, args.mod))
        c.execute('UPDATE methdata SET methstate=0 WHERE modname=?', (args.mod,))

        logger.info('%s: mark sites with stat > %f methylated (1) for mod %s' % (args.db, upper, args.mod))
        c.execute('UPDATE methdata SET methstate=1 WHERE stat > ? AND modname=?', (upper, args.mod))

        logger.info('%s: mark sites with stat < %f unmethylated (-1) for mod %s' % (args.db, lower, args.mod))
        c.execute('UPDATE methdata SET methstate=-1 WHERE stat < ? AND modname=?', (lower, args.mod))

        logger.info('%s: update cutoff table for mod %s' % (args.db, args.mod))
        c.execute('UPDATE cutoffs SET upper=?, lower=? WHERE modname=?', (upper, lower, args.mod))

        conn.commit()

    finally:
        conn.close()

def db_guppy(args):
    '''
    Build (or append to) <samplename>.guppy.db from fast5 files and alignments.

    Reads args.fast5 (directory), args.bam (comma-delimited), args.samplename,
    args.procs, args.motifsize, args.minprob, args.modname and args.append.

    Steps:
      1. cache primary alignments of every bam into <samplename>.bamcache.db
         (skipped when that file already exists)
      2. run guppy_f5_fetch on every *.fast5 file in args.fast5 in a worker pool;
         each call yields the name of a tab-delimited table, or None
      3. load the tables (columns readname, chrom, pos, strand, mod_log_prob,
         can_log_prob) into the methdata table
    '''
    if not ont_fast5_api_installed:
        sys.exit('ont_fast5_api is not installed but is required for this function. Please install e.g. via "pip install ont-fast5-api"')

    assert os.path.exists(args.fast5), 'path not found: %s' % args.fast5

    db_fn = args.samplename + '.guppy.db'

    # check the output database before doing any expensive work
    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    bam_db = args.samplename + '.bamcache.db'

    if os.path.exists(bam_db):
        logger.info('using existing bam cache db %s' % bam_db)

    else:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))
        conn = sqlite3.connect(bam_db)
        c = conn.cursor()

        logger.info('caching %s to %s' % (args.bam, bam_db))
        c.execute('''CREATE TABLE bam (readname text, sam text)''')
        c.execute('''CREATE INDEX read_index ON bam(readname)''')

        commit_interval = 100000

        for bam_fn in args.bam.split(','):
            bam = pysam.AlignmentFile(bam_fn)

            for i, read in enumerate(bam.fetch(), 1):
                if read.is_secondary or read.is_supplementary or read.is_duplicate:
                    continue

                read.query_qualities = None

                # bound parameters: SAM text may contain quote characters
                c.execute("INSERT INTO bam VALUES (?, ?)", (read.query_name, read.to_string()))

                if i % commit_interval == 0:
                    logger.info('commiting %d records to %s...' % (commit_interval, bam_db))
                    conn.commit()

            bam.close()

            logger.info('commiting remaining records to %s...' % bam_db)

        conn.commit()
        conn.close()

    fast5s = [fn for fn in os.listdir(args.fast5) if fn.endswith('.fast5')]

    logger.info('found %d fast5 files in %s' % (len(fast5s), args.fast5))

    logger.info('fetching base modifications...')

    # the pool is shut down as soon as every result has been collected
    with mp.Pool(processes=int(args.procs)) as pool:
        results = [pool.apply_async(guppy_f5_fetch, [fast5, bam_db, args]) for fast5 in fast5s]
        outfiles = [res.get() for res in results]

    conn = sqlite3.connect(db_fn)
    c = conn.cursor()

    if not args.append:
        c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
        c.execute('''CREATE TABLE modnames (mod text)''')
        c.execute('''CREATE INDEX read_index ON methdata(readname)''')
        c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

    else:
        logger.info('appending records to %s' % db_fn)

    minprob = float(args.minprob)
    mod_base = args.modname
    motif_shift = int(args.motifsize)-1

    llr = None

    for fn in outfiles:
        if fn is None:
            continue

        with open(fn) as tsv:
            logger.info('loading %s into %s' % (fn, db_fn))
            csv_reader = csv.DictReader(tsv, delimiter='\t')

            for row in csv_reader:
                readname = row['readname']
                chrom    = row['chrom']
                pos      = int(row['pos'])
                strand   = row['strand']

                mod_prob = float(row['mod_log_prob'])
                can_prob = float(row['can_log_prob'])

                llr = mod_prob - can_prob

                # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
                if strand == '-':
                    pos -= motif_shift

                methcall = 0

                if mod_prob >= minprob:
                    methcall = 1

                if can_prob >= minprob:
                    methcall = -1

                assert (mod_prob + can_prob) < 1.01 # allow for some rounding error

                # stat is stored rounded to two decimals
                ins_data = (chrom, pos, strand, readname, float('%.2f' % llr), methcall, mod_base)
                c.execute("INSERT INTO methdata VALUES (?,?,?,?,?,?,?)", ins_data)

            c.execute("INSERT INTO modnames VALUES (?)", (args.modname,))

            # NOTE(review): the cutoffs are taken from the stat of the last row
            # read so far, not from --minprob; confirm this is the intended value
            if not args.append and llr is not None:
                c.execute("INSERT INTO cutoffs VALUES (?,?,?)", (float('%.4f' % llr), float('%.4f' % (-1*llr)), args.modname))

            logger.info('commiting records from %s to %s' % (fn, db_fn))
            conn.commit()

    conn.close()
    logger.info('finished.')


def db_megalodon(args):
    '''
    Build (or append to) a methylartist sqlite database from megalodon
    per-read modified base tables.

    Reads args.methdata (comma-delimited .tsv / .tsv.gz paths), args.db,
    args.append, args.motifsize and args.minprob. Input columns used: read_id,
    chrm, pos, strand, mod_log_prob, can_log_prob, mod_base.

    The stored stat is exp(mod_log_prob). A call is 1 when that probability is
    >= minprob, -1 when exp(can_log_prob) is >= minprob, otherwise 0. Calls on
    the '-' strand are shifted left by motifsize-1.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.megalodon.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    motif_shift = int(args.motifsize)-1
    minprob = float(args.minprob)
    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            logger.info('parsing ' + tsv)

            modnames = {}
            ins_data = []

            methdata = gzip.open(tsv, 'rt') if tsv.endswith('.gz') else open(tsv)

            # the with block guarantees the input handle is closed
            with methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                for i, row in enumerate(csv_reader, 1):
                    readname = row['read_id']
                    chrom    = row['chrm']
                    pos      = int(row['pos'])
                    strand   = row['strand']
                    mod_prob = np.exp(float(row['mod_log_prob']))
                    can_prob = np.exp(float(row['can_log_prob']))

                    mod_base = row['mod_base']

                    modnames[mod_base] = True

                    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
                    if strand == '-':
                        pos -= motif_shift

                    methcall = 0

                    if mod_prob >= minprob:
                        methcall = 1

                    if can_prob >= minprob:
                        methcall = -1

                    ins_data.append((chrom, pos, strand, readname, mod_prob, methcall, mod_base))

                    # flush in batches to bound memory use
                    if i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            for mod in modnames:
                c.execute("INSERT INTO modnames VALUES (?)", (mod,))
                c.execute("INSERT INTO cutoffs VALUES (?,?,?)", (float('%.4f' % minprob), float('%.4f' % (1-minprob)), mod))

            if len(ins_data) > 0:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            logger.info('commiting records from %s to %s' % (tsv, db_fn))
            conn.commit()

    finally:
        conn.close()


def db_custom(args):
    '''
    Build (or append to) a methylartist sqlite database from a user-defined
    per-read table.

    Reads args.methdata (comma-delimited paths, .gz allowed), args.db,
    args.append, args.motifsize, args.header, args.delimiter, the 0-based
    column indices args.readname, args.chrom, args.pos, args.strand,
    args.modprob, and optionally args.canprob and args.modbasecol, plus
    args.modbase, args.minmodprob and args.mincanprob.

    The mod name is args.modbase when given, otherwise the value in column
    args.modbasecol; the program exits if neither is available. Without a
    canonical-probability column, 1 - mod probability is used. mincanprob
    defaults to minmodprob.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])
    if basename == 'per_read_modified_base_calls' and not args.db:
        sys.exit('default megalodon filename (per_read_modified_base_calls) is not informative: please specify a database name with --db')

    db_fn = basename + '.custom.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))

    motif_shift = int(args.motifsize)-1
    progress_interval = 1000000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    c_readname = int(args.readname)
    c_chrom    = int(args.chrom)
    c_pos      = int(args.pos)
    c_strand   = int(args.strand)
    c_modprob  = int(args.modprob)

    c_canprob = None

    if args.canprob is not None:
        c_canprob = int(args.canprob)

    c_modbasecol = None

    if args.modbasecol is not None:
        c_modbasecol = int(args.modbasecol)

    minmodprob = float(args.minmodprob)
    mincanprob = float(args.minmodprob)

    if args.mincanprob is not None:
        mincanprob = float(args.mincanprob)

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            logger.info('parsing ' + tsv)

            modnames = {}
            ins_data = []

            methdata = gzip.open(tsv, 'rt') if tsv.endswith('.gz') else open(tsv)

            # the with block guarantees the input handle is closed
            with methdata:
                for i, row in enumerate(methdata):
                    if args.header and i == 0:
                        continue

                    if args.delimiter is not None:
                        cols = row.strip().split(args.delimiter)
                    else:
                        cols = row.strip().split()

                    if len(cols) < 5:
                        sys.exit('table %s has < 5 columns at line %d' % (tsv, i))

                    readname = cols[c_readname]
                    chrom    = cols[c_chrom]
                    pos      = int(cols[c_pos])
                    strand   = cols[c_strand]
                    mod_prob = float(cols[c_modprob])
                    can_prob = 1.0-mod_prob

                    if c_canprob is not None:
                        can_prob = float(cols[c_canprob])

                    # a fixed --modbase wins, otherwise take the name from the table
                    mod_base = None

                    if args.modbase is not None:
                        mod_base = args.modbase

                    elif c_modbasecol is not None:
                        mod_base = cols[c_modbasecol]

                    if mod_base is None:
                        sys.exit('must specify either --modbase or --modbasecol')

                    modnames[mod_base] = True

                    # adjust position of - strand calls based on mod motif size (default = 2 as CG is probably the most frequent use case)
                    if strand == '-':
                        pos -= motif_shift

                    methcall = 0

                    if mod_prob >= minmodprob:
                        methcall = 1

                    if can_prob >= mincanprob:
                        methcall = -1

                    ins_data.append((chrom, pos, strand, readname, mod_prob, methcall, mod_base))

                    # flush in batches to bound memory use
                    if i and i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            # every mod seen in this table gets its own modnames and cutoffs row
            for mod in modnames:
                c.execute("INSERT INTO modnames VALUES (?)", (mod,))
                c.execute("INSERT INTO cutoffs VALUES (?,?,?)", (float('%.4f' % minmodprob), float('%.4f' % (1-mincanprob)), mod))

            if len(ins_data) > 0:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            logger.info('commiting records from %s to %s' % (tsv, db_fn))
            conn.commit()

    finally:
        conn.close()


def db_nanopolish(args):
    '''
    Build (or append to) a methylartist sqlite database from nanopolish
    call-methylation tables.

    Reads args.methdata (comma-delimited .tsv / .tsv.gz paths), args.db,
    args.append, args.modname, args.scalegroup, args.thresh and args.motif.
    Input columns used: chromosome, strand, start, read_name, log_lik_ratio,
    sequence and (with args.scalegroup) num_motifs.

    A call is 1 when the log-likelihood ratio is > thresh, -1 when it is
    < -thresh, otherwise 0. One record is stored per occurrence of args.motif
    in the row's sequence, all sharing the row's ratio and call. Rows whose
    fields are missing or cannot be parsed are logged and skipped.
    '''
    basename = '.'.join(args.methdata.split('.')[:-1])

    if basename.endswith('.tsv'):
        basename = '.'.join(basename.split('.')[:-1])

    db_fn = basename + '.nanopolish.db'

    if args.db:
        if args.db.endswith('.db'):
            db_fn = args.db
        else:
            db_fn = args.db + '.db'

    logger.info('database output name: ' + db_fn)

    if os.path.exists(db_fn) and not args.append:
        sys.exit('database %s already exists' % db_fn)

    if args.append and not os.path.exists(db_fn):
        sys.exit('database %s does not exist and --append has been called' % db_fn)

    thresh = float(args.thresh)
    mod_base = args.modname
    progress_interval = 500000
    insert_sql = 'INSERT INTO methdata VALUES (?,?,?,?,?,?,?)'

    conn = sqlite3.connect(db_fn)

    try:
        c = conn.cursor()

        if not args.append:
            c.execute('''CREATE TABLE methdata (chrom text, pos integer, strand text, readname text, stat real, methstate integer, modname text)''')
            c.execute('''CREATE TABLE modnames (mod text)''')
            c.execute('''CREATE INDEX read_index ON methdata(readname)''')
            c.execute('''CREATE TABLE cutoffs (upper real, lower real, modname text)''')

        else:
            logger.info('appending records to %s' % db_fn)

        for tsv in args.methdata.split(','):
            logger.info('parsing ' + tsv)

            ins_data = []

            methdata = gzip.open(tsv, 'rt') if tsv.endswith('.gz') else open(tsv)

            # the with block guarantees the input handle is closed
            with methdata:
                csv_reader = csv.DictReader(methdata, delimiter='\t')

                for i, row in enumerate(csv_reader, 1):
                    # only malformed-row errors are tolerated here: a missing
                    # column (KeyError), a short row (None -> TypeError) or an
                    # unparseable number (ValueError)
                    try:
                        r_start  = int(row['start'])
                        llr      = float(row['log_lik_ratio'])
                        seq      = row['sequence']
                    except (KeyError, TypeError, ValueError):
                        logger.warning('bad line %d: %s' % (i, str(row)))
                        continue

                    if args.scalegroup:
                        llr = llr/float(row['num_motifs'])

                    methcall = 0

                    if llr > thresh:
                        methcall = 1

                    elif llr < thresh*-1:
                        methcall = -1

                    # get per-CG position (nanopolish/calculate_methylation_frequency.py)
                    if args.motif not in seq:
                        sys.exit('motif %s not found in kmer %s, please check --motif setting' % (args.motif, seq))

                    cg_pos = seq.find(args.motif)
                    first_cg_pos = cg_pos

                    while cg_pos != -1:
                        cg_start = r_start + cg_pos - first_cg_pos
                        cg_pos = seq.find(args.motif, cg_pos + 1)

                        ins_data.append((row['chromosome'], cg_start, row['strand'], row['read_name'], llr, methcall, mod_base))

                    # flush in batches to bound memory use
                    if i % progress_interval == 0:
                        conn.executemany(insert_sql, ins_data)
                        ins_data = []
                        logger.info('processed %d records from %s' % (i, tsv))

            if len(ins_data) > 0:
                conn.executemany(insert_sql, ins_data)
                ins_data = []

            c.execute("INSERT INTO modnames VALUES (?)", (args.modname,))

            c.execute("INSERT INTO cutoffs VALUES (?,?,?)", (float('%.4f' % thresh), float('%.4f' % (thresh*-1)), args.modname))

            logger.info('commiting records from %s to %s' % (tsv, db_fn))

            conn.commit()

    finally:
        conn.close()


def _segmeth_read_intervals(intervals_fn):
    # parse the interval file once: (chrom, start, end, name, strand) per non-blank line
    segments = []

    with open(intervals_fn) as ivl:
        for line in ivl:
            c = line.strip().split()

            if not c:
                continue

            chrom, seg_start, seg_end = c[:3]

            seg_name = c[3] if len(c) > 3 else 'NA'
            seg_strand = c[4] if len(c) > 4 else 'NA'

            segments.append((chrom, int(seg_start), int(seg_end), seg_name, seg_strand))

    return segments


def _segmeth_tabulate(fetched):
    # turn ((meth_result, seg), base_name) pairs into one dict of output cells per segment
    meth_segs = dd(dict)

    for (meth_result, seg), base_name in fetched:
        if meth_result is None:
            continue

        seg_id = '%s:%d-%d' % seg[:3]

        seg_chrom, seg_start, seg_end, seg_name, seg_strand, read_count = map(str, seg)

        meth_segs[seg_id]['seg_id']     = seg_id
        meth_segs[seg_id]['seg_chrom']  = seg_chrom
        meth_segs[seg_id]['seg_start']  = seg_start
        meth_segs[seg_id]['seg_end']    = seg_end
        meth_segs[seg_id]['seg_name']   = seg_name
        meth_segs[seg_id]['seg_strand'] = seg_strand

        if seg_name == 'NA':
            meth_segs[seg_id]['seg_name'] = 'NoName'

        if seg_strand == 'NA':
            meth_segs[seg_id]['seg_strand'] = '.'

        for modname, meth_data in meth_result.items():
            unmeth_calls = meth_data[-1] if -1 in meth_data else 0
            no_calls = meth_data[0] if 0 in meth_data else 0
            meth_calls = meth_data[1] if 1 in meth_data else 0

            prefix = os.path.basename(base_name) + '_' + modname

            meth_segs[seg_id][prefix + '_meth_calls'] = meth_calls
            meth_segs[seg_id][prefix + '_unmeth_calls'] = unmeth_calls
            meth_segs[seg_id][prefix + '_no_calls'] = no_calls

            if meth_calls+unmeth_calls == 0:
                meth_segs[seg_id][prefix + '_methfrac'] = 'NaN'
            else:
                meth_segs[seg_id][prefix + '_methfrac'] = meth_calls/float(meth_calls+unmeth_calls)

            meth_segs[seg_id][prefix + '_readcount'] = read_count

    return meth_segs


def _segmeth_write_table(out, header, meth_segs):
    # write the header plus one full-width row per segment; absent cells become NaN
    out.write('\t'.join(header) + '\n')

    for mseg, cells in meth_segs.items():
        missing = [h for h in header if h not in cells]

        if missing:
            logger.warning('no calls for %s (%d column(s) in total) in segment %s, writing NaN.' % (missing[0], len(missing), mseg))

        out.write('\t'.join(str(cells.get(h, 'NaN')) for h in header) + '\n')


def segmeth(args):
    '''
    segment methylation stats over genomic intervals

    Inputs are either args.data (a file with one "<bam> <comma-delimited
    methylartist dbs>" pair per line) or args.bams (comma-delimited .bam
    paths, or a file listing one .bam per line), but not both. args.intervals
    holds whitespace-delimited chrom, start, end and optionally name, strand.
    Also reads args.phased, args.excl_ambig, args.spanning_only and args.procs.

    get_segmeth_calls is run for every (bam, phase, interval) in a worker
    pool and the results are written to
    <intervals>.<data>[.excl_ambig][.spanning_only][.phased].segmeth.tsv
    with meth_calls, unmeth_calls, no_calls, methfrac and readcount columns
    per sample and mod. Every row has the full set of columns; cells without
    a result are written as NaN.
    '''

    stats = [
    'meth_calls',
    'unmeth_calls',
    'no_calls',
    'methfrac',
    'readcount'
    ]

    mod_names = []

    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                bam, meth = line.strip().split()
                for m_db in meth.split(','):
                    data[bam].append(m_db)

                    for m in get_modnames(m_db):
                        mod_names.append(m)

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    bams.append(line.strip())

        for bam in bams:
            if ':' in bam:
                bam, _ = bam.split(':')

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # NOTE(review): mod names are taken from the last bam only - confirm
        # that all bams are expected to carry the same mods
        mod_names = mods_methbam(bam)

    mod_names = list(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    base_names = ['.'.join(bam.split('.')[:-1]) for bam in data]

    if args.phased:
        phased_base_names = []
        for bn in base_names:
            phased_base_names.append(bn+'_ph1')
            phased_base_names.append(bn+'_ph2')
        base_names = phased_base_names

    data_basename = None

    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'

    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])

    ivl_basename = '.'.join(os.path.basename(args.intervals).split('.')[:-1])
    outfn = '.'.join((ivl_basename, data_basename))

    if args.excl_ambig:
        outfn += '.excl_ambig'

    if args.spanning_only:
        outfn += '.spanning_only'

    if args.phased:
        outfn += '.phased'

    outfn += '.segmeth.tsv'

    # the interval list is the same for every bam and phase: read it once
    segments = _segmeth_read_intervals(args.intervals)

    phases = [None]

    if args.phased:
        phases = ['1', '2']

    header = ['seg_id', 'seg_chrom', 'seg_start', 'seg_end', 'seg_name', 'seg_strand']

    for bn in base_names:
        for m in mod_names:
            for s in stats:
                header.append('%s_%s_%s' % (os.path.basename(bn), m, s))

    # the output is opened before the work starts so an unwritable path fails
    # early; the with block closes it even if a worker raises
    with open(outfn, 'w') as out:
        logger.info('segmeth output filename: %s' % outfn)

        with mp.Pool(processes=int(args.procs)) as pool:
            results = []

            for bam_fn, meth_fn in data.items():
                for phase in phases:
                    base_name = '.'.join(bam_fn.split('.')[:-1])

                    if phase is not None:
                        base_name += '_ph' + phase

                    for chrom, seg_start, seg_end, seg_name, seg_strand in segments:
                        res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mod_names, meth_fn, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam])

                        results.append((res, base_name))

            fetched = [(res.get(), base_name) for res, base_name in results]

        meth_segs = _segmeth_tabulate(fetched)

        _segmeth_write_table(out, header, meth_segs)


def segplot(args):
    '''
    strip plots / violin plots / ridge plots of segmented methylation data over categories by sample

    Reads the tab-delimited table named by args.segmeth (first column used as the index), recomputes
    the per-segment methylated fraction for every sample/mod combination, drops segments that fail
    --mincalls or --minreads in any selected sample/mod, writes the long-format plotting table to
    <basename>.segplot_data.csv and saves the figure as <basename>[...].segplot.png (or .svg with args.svg).

    args attributes read here: segmeth, ridge, violin, samples, mods, mincalls, minreads, pointsize,
    categories, group_by_annotation, palette, ridge_alpha, ridge_spacing, ylabel, tiltlabel, ymin, ymax,
    width, height, svg. args.width is overwritten when it is None and the plot is not a ridge plot.

    Exits via sys.exit() if --ridge is requested with seaborn older than 0.11.2 or if a requested
    sample / mod is not present in the table.
    '''

    if args.ridge:
        # compare (major, minor, patch) as one tuple: checking minor and patch independently
        # rejects valid releases such as 0.12.0 or 0.13.0 (patch < 2) and any 1.x (minor < 11)
        v = []
        for part in sns.__version__.split('.')[:3]:
            # tolerate suffixes such as '0rc1' or 'dev0' by keeping only the leading digits
            digits = ''.join(itertools.takewhile(str.isdigit, part))
            v.append(int(digits) if digits else 0)

        while len(v) < 3:
            v.append(0)

        if tuple(v) < (0, 11, 2):
            sys.exit('--ridge requires seaborn 0.11.2 or later')

    data = pd.read_csv(args.segmeth, sep='\t', header=0, index_col=0)

    samples = []
    mods = []

    # column names follow <sample>_<mod>_<stat>; the mod is the last underscore-delimited field
    for c in data.columns:
        if c.endswith('_meth_calls'):
            sample_mod = c.replace('_meth_calls', '')
            sample = '_'.join(sample_mod.split('_')[:-1])
            mod = sample_mod.split('_')[-1]

            # a sample has one *_meth_calls column per mod: keep each sample once so the
            # sample x mod product below does not contain duplicates
            if sample not in samples:
                samples.append(sample)

            if mod not in mods:
                mods.append(mod)

    logger.info('available samples: %s' % ','.join(samples))
    logger.info('available mods: %s' % ','.join(mods))

    if args.samples:
        user_samples = []
        for s in args.samples.split(','):
            if s not in samples:
                sys.exit('%s not in sample list!' % s)
            user_samples.append(s)

        samples = user_samples

    if args.mods:
        user_mods = []
        for m in args.mods.split(','):
            if m not in mods:
                sys.exit('%s not in mod list!' % m)
            user_mods.append(m)

        mods = user_mods

    samples_mods = [s + '_' + m for s, m in itertools.product(samples, mods)]

    logger.info('sample + mod permutations: %s' % ','.join(samples_mods))

    # recompute methfrac from the call counts so the column is numeric (NaN where there are no calls)
    for sm in samples_mods:
        data[sm + '_methfrac'] = data[sm + '_meth_calls']/(data[sm + '_meth_calls']+data[sm + '_unmeth_calls'])

    min_calls = int(args.mincalls)
    min_reads = int(args.minreads)

    useable = []

    # a segment is kept only if every selected sample/mod passes both thresholds
    for seg in data.index:
        use_seg = True

        for sm in samples_mods:
            if (data[sm + '_meth_calls'].loc[seg] + data[sm + '_unmeth_calls'].loc[seg]) < min_calls:
                use_seg = False
                break

            if data[sm+'_readcount'].loc[seg] < min_reads:
                use_seg = False
                break

        if use_seg:
            useable.append(seg)

    data = data.loc[useable]

    logger.info('useable sites: %d' % len(useable))

    plot_data = dd(dict)

    # group (seg_name) values in order of first appearance; used as the default category order
    order = []

    for seg in data.index:
        for sm in samples_mods:
            uid = '%s:%s' % (seg, sm)
            plot_data[uid]['sample']  = sm
            plot_data[uid]['modbase'] = data[sm + '_methfrac'].loc[seg]
            plot_data[uid]['group']   = data['seg_name'].loc[seg]

            if args.ridge:
                # one facet row per sample/group pair in the ridge plot
                plot_data[uid]['samplegroup'] = sm + ' ' + plot_data[uid]['group']

            if plot_data[uid]['group'] not in order:
                order.append(plot_data[uid]['group'])

    # round trip through a dict so pandas re-infers column dtypes (the transposed frame is all object)
    plot_data = pd.DataFrame.from_dict(plot_data).T
    plot_data = pd.DataFrame(plot_data.to_dict())

    basename = '.'.join(args.segmeth.split('.')[:-1])

    basename += '.mc%d' % min_calls
    basename += '.mr%d' % min_reads

    plot_data.to_csv(basename+'.segplot_data.csv')
    logger.info('plot data written to %s.segplot_data.csv' % basename)

    pt_sz = int(args.pointsize)

    if args.categories is not None:
        order = args.categories.split(',')
        plot_data = plot_data[plot_data['group'].isin(order)]

    if args.violin:
        if args.group_by_annotation:
            sns_plot = sns.violinplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, jitter=True, order=order, hue_order=samples_mods, palette=args.palette)
        else:
            sns_plot = sns.violinplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, jitter=True, hue_order=order, palette=args.palette)

        basename += '.violin'

    elif args.ridge:
        sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

        # ordered categoricals make sort_values follow the sample / category order instead of alphabetical
        sample_order = CategoricalDtype(samples_mods, ordered=True)
        plot_data['sample'] = plot_data['sample'].astype(sample_order)

        if args.categories:
            cat_order = CategoricalDtype(order, ordered=True)
            plot_data['group'] = plot_data['group'].astype(cat_order)

        plot_data = plot_data.sort_values(['sample','group'])

        if args.group_by_annotation:
            plot_data = plot_data.sort_values(['group','sample'])

        sns_plot = sns.FacetGrid(plot_data, row='samplegroup', hue='sample', aspect=15, height=.5, palette=args.palette)
        sns_plot.map(sns.kdeplot, 'modbase', bw_adjust=.5, clip_on=False, fill=True, alpha=float(args.ridge_alpha), linewidth=1.5)
        sns_plot.map(sns.kdeplot, 'modbase', clip_on=False, color='w', lw=2, bw_adjust=.5)
        sns_plot.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)
        sns_plot.figure.subplots_adjust(hspace=float(args.ridge_spacing))

        seen_samples = set()
        seen_groups = set()

        s_col = sns.color_palette(args.palette, n_colors=len(set(plot_data['sample'])))

        s = 0

        for i in range(len(set(plot_data['samplegroup']))):
            ax = sns_plot.axes[i,0]

            # facet title is '<row var> = <sample_mod> <group>'; split once so group names may contain spaces
            samplename, groupname = ax.title.get_text().split('=')[-1].strip().split(None, 1)

            # when grouping by annotation label each group once, otherwise label every row
            if args.group_by_annotation and groupname not in seen_groups:
                ax.text(0, .15, groupname, color='k', ha='left', va='center', transform=ax.transAxes)

            if not args.group_by_annotation:
                ax.text(0, .15, groupname, color='k', ha='left', va='center', transform=ax.transAxes)

            if samplename not in seen_samples:
                ax.text(0, -.15, samplename, size=8, weight='bold', color=s_col[s], ha='left', va='center', transform=ax.transAxes)
                s += 1

            seen_groups.add(groupname)
            seen_samples.add(samplename)

        sns_plot.set_titles("")
        sns_plot.set(yticks=[], ylabel="")
        sns_plot.despine(bottom=True, left=True)

        basename += '.ridge'

    else:
        if args.group_by_annotation:
            sns_plot = sns.stripplot(x='group', y='modbase', data=plot_data, hue='sample', dodge=True, jitter=True, size=pt_sz, order=order, hue_order=samples_mods, palette=args.palette)
        else:
            sns_plot = sns.stripplot(x='sample', y='modbase', data=plot_data, hue='group', dodge=True, jitter=True, size=pt_sz, hue_order=order, palette=args.palette)

    if args.group_by_annotation:
        basename += '.group_by_annotation'

    if not args.ridge:
        sns_plot.set_xlabel("")
        sns_plot.set_ylabel(args.ylabel)

        if args.tiltlabel:
            sns_plot.set_xticklabels(sns_plot.get_xticklabels(), rotation=45)

        sns_plot.set_ylim(float(args.ymin),float(args.ymax))

        if args.width is None:
            # one unit of width per category x sample/mod combination, plus a margin
            args.width = 1 + len(order) * len(samples_mods)
            logger.info('auto set --width: %d' % args.width)

    fig = sns_plot.figure

    if not args.ridge:
        fig.set_size_inches(float(args.width), float(args.height))

    if args.svg:
        fig.savefig(basename+'.segplot.svg', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.svg' % basename)
    else:
        fig.savefig(basename+'.segplot.png', bbox_inches='tight')
        logger.info('plot saved to %s.segplot.png' % basename)


def _locus_tabulate_calls(meth_table, sample_order, sample_reads, bamname):
    '''
    add one row per (read, position) call of sample `bamname` to meth_table (dict of dicts keyed by a
    random uuid) and append the sample to sample_order the first time it contributes a call
    '''
    for name, read in sample_reads.items():
        for loc in read.llrs.keys():
            uuid = str(uuid4())
            meth_table[uuid]['loc'] = loc
            meth_table[uuid]['modstat'] = read.llrs[loc]
            meth_table[uuid]['read'] = name
            meth_table[uuid]['sample'] = bamname
            meth_table[uuid]['call'] = read.meth_calls[loc]

            if bamname not in sample_order:
                sample_order.append(bamname)


def _locus_add_highlights(ax, starts, ends, colours, alpha):
    '''
    shade each (start, end) interval on ax with its colour, spanning the y-extent ax has when called
    '''
    ymin = min(ax.get_ybound())
    yheight = max(ax.get_ybound())-ymin

    for h_s, h_e, h_color in zip(starts, ends, colours):
        ax.add_patch(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=alpha, zorder=1))


def locus(args):
    '''
    plot methylation profile of a region in CpG space

    Builds a five-panel figure for args.interval (chrom:start-end): gene models, per-read calls,
    the mapping between original coordinates and dense-ranked modified-base positions ("mod space"),
    the per-sample modification statistic, and the smoothed methylated fraction. Input comes either
    from -d/--data (text file: .bam, comma-delimited methylation .db list, optional colour) or from
    -b/--bams (.bam[:colour] list or a file listing them), never both.

    Side effects: args.mincalls, args.modspace and args.smoothwindowsize are converted to int (and
    auto-set when None); the figure is written to <prefix>.locus.meth.png (or .svg with args.svg).

    Exits via sys.exit() on invalid marker, missing/ambiguous inputs, unindexed bams or when no
    calls remain for the interval.
    '''
    # set up

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos= args.interval.split(':')
    elt_start, elt_end = pos.split('-')

    elt_start = int(elt_start.replace(',',''))
    elt_end = int(elt_end.replace(',',''))

    # bam path --> list of methylation dbs (-d/--data) or None (-b/--bams)
    data = dd(list)
    user_colours = {}

    markers = ['.',',','o','v','^','<','>','8','s','p','P','*','h','H','X','D','d']

    if args.readmarker not in markers:
        sys.exit('%s is not a valid marker type, valid types: %s' % (args.readmarker, ','.join(markers)))

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)

                if len(c) == 3:
                    user_colours[bam] = c[2]

    if args.bams is not None:
        logger.info('mod motif size (--motifsize) = %d (ensure this is correct for your data)' % int(args.motifsize))
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    # skip blank lines (e.g. a trailing newline) rather than treating them as a bam path
                    if line.strip():
                        bams.append(line.strip())

        for bam in bams:
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

    # table for plotting

    meth_table = dd(dict)

    sample_order = []

    reads = {}
    orig_bam = {}

    use_mods = None
    if args.mods:
        use_mods = args.mods.split(',')

    phases = {}

    if args.phased:
        for bam in data:
            phased_reads = get_phased_reads(bam, chrom, elt_start, elt_end, tag_untagged=args.include_unphased, min_mapq=int(args.min_mapq), HP_only=args.ignore_ps)
            phases[bam] = list(set([p for p in phased_reads.values() if p]))
            logger.info('phases for bam %s: %s' % (bam, ','.join(phases[bam])))

    for bam, meth_dbs in data.items():
        if meth_dbs is None:
            mods = mods_methbam(bam)
            logger.info('found mods: %s in bam %s' % (','.join(mods), bam))

        else:
            # NOTE(review): with several dbs per bam only the mods of the last db are kept -- confirm intended
            for meth_db in meth_dbs:
                mods = sorted(get_modnames(meth_db))
                logger.info('found mods: %s in db %s' % (','.join(mods), meth_db))

        for mod in mods:
            if use_mods:
                if mod not in use_mods:
                    logger.info('skipping %s, not specified in -m/--mods %s' % (mod, args.mods))
                    continue
            else:
                # no -m/--mods given: use every mod found in the first input
                use_mods = mods

            if args.phased:
                # one pseudo-sample per (bam, phase, mod)
                for phase in phases[bam]:
                    bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + phase + '.' + mod
                    orig_bam[bamname] = bam
                    reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, phase=phase, methbam=methbam, HP_only=args.ignore_ps)
                    _locus_tabulate_calls(meth_table, sample_order, reads[bamname], bamname)

            else:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod
                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, methbam=methbam)
                _locus_tabulate_calls(meth_table, sample_order, reads[bamname], bamname)

    meth_table = pd.DataFrame.from_dict(meth_table).T

    if 'loc' not in meth_table:
        sys.exit('%s: insufficient coverage for plot' % args.interval)

    meth_table['loc'] = pd.to_numeric(meth_table['loc'])
    meth_table['modstat'] = pd.to_numeric(meth_table['modstat'])

    # keep the original coordinate and replace 'loc' by its dense rank (consecutive mod-space index)
    meth_table['orig_loc'] = meth_table['loc']
    meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

    if len(meth_table['orig_loc']) == 0:
        sys.exit('%s: insufficient coverage for plot' % args.interval)

    # optional mincall filter
    args.mincalls = int(args.mincalls)

    if args.mincalls > 0:
        logger.info('assessing coverage per --mincalls, per-read sites to consider: %d' % len(meth_table['loc']))

        # calls per (position, sample) in one pass; a sample without calls at a position counts as 0
        call_counts = meth_table.groupby(['loc', 'sample']).size().unstack(fill_value=0)
        call_counts = call_counts.reindex(columns=sample_order, fill_value=0)

        # a position is dropped (for all samples) when any sample has fewer than --mincalls calls there
        low_cov_locs = call_counts.index[(call_counts < args.mincalls).any(axis=1).values]
        drop_ix = list(meth_table.index[meth_table['loc'].isin(low_cov_locs).values])

        logger.info('dropping %d positions' % len(drop_ix))
        meth_table.drop(drop_ix, inplace=True)

        # re-rank so mod space stays consecutive after dropping positions
        meth_table['loc'] = ss.rankdata(meth_table['loc'], method='dense')

        if len(meth_table['orig_loc']) == 0:
            sys.exit('%s: insufficient coverage for plot' % args.interval)

    # create mod space

    coord_to_cpg = {}
    for orig_loc, new_loc in zip(meth_table['orig_loc'], meth_table['loc']):
        coord_to_cpg[orig_loc] = new_loc

    # calibrate plotting parameters
    logger.info('region size: %d' % (elt_end-elt_start))

    modspace_n = len(set(coord_to_cpg.values()))
    logger.info('mod space positions: %d' % modspace_n)

    if args.modspace is not None:
        logger.info('user set --modspace %d, mod_n:ticks: %.3f' % (int(args.modspace), modspace_n/float(args.modspace)))
        args.modspace = int(args.modspace)
    else:
        if modspace_n <= 1200:
            args.modspace = 1
        else:
            args.modspace = round(modspace_n/1200)

        logger.info('auto set --modspace %d' % args.modspace)

    if args.smoothwindowsize is not None:
        logger.info('user set --smoothwindowsize %d, mod_n:smoothwindowsize: %.3f' % (int(args.smoothwindowsize), modspace_n/float(args.smoothwindowsize)))
        args.smoothwindowsize = int(args.smoothwindowsize)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            logger.info('-s/--smoothwindowsize must be an even integer, adjusted value: %d' % args.smoothwindowsize)

    else:
        args.smoothwindowsize = round(0.0167*modspace_n + 18)

        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1

        logger.info('auto set --smoothwindowsize %d' % args.smoothwindowsize)

    # highlights

    # h_start/h_end are offsets from elt_start; h_cpg_* are the matching mod-space positions
    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    def nearest_mod_pos(rel_pos):
        # mod-space index of the called position closest to rel_pos
        return coord_to_cpg[min(meth_table['orig_loc'], key=lambda x:abs(x-rel_pos))]

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]

            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s - elt_start)
            h_end.append(h_e - elt_start)

            h_cpg_start.append(nearest_mod_pos(h_start[-1]))
            h_cpg_end.append(nearest_mod_pos(h_end[-1]))

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()
                assert c[0] == chrom, 'all entries in --highlight_bed must be on the same chromosome as -i/--interval'

                h_s, h_e = map(int, c[1:3])

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1

                h_start.append(h_s - elt_start)
                h_end.append(h_e - elt_start)

                h_cpg_start.append(nearest_mod_pos(h_start[-1]))
                h_cpg_end.append(nearest_mod_pos(h_end[-1]))

                h_colors.append(colour)

        # entries without a colour column draw from the highlight palette
        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    fig = plt.figure()
    gs = gridspec.GridSpec(5,1,height_ratios=[1,5,1,3,3], hspace=0)

    img_w = 16
    img_h = 8

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(5,1,height_ratios=p_ratios, hspace=0)

    sample_palette = sns.color_palette(args.samplepalette, n_colors=len(sample_order))

    sample_color = {}
    for i, sample in enumerate(sample_order):
        sample_color[sample] = sample_palette[i]

    # user-specified colours (per bam) override the palette
    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_palette = sns.color_palette(args.samplepalette, n_colors=2)
            hp_color = {'1': hp_palette[0], '2': hp_palette[1]}

            for sample in sample_color:
                # phased sample names end in .<phase>.<mod>
                hp = sample.split('.')[-2]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phase')


    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = []

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, elt_start, elt_end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    # one interval tree per display row, used to stack overlapping genes
    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]

        genes = new_genes

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            logger.warning('no transcript for gene: %s' % genes[ensg].name)
            continue

        logger.info('gene in region: %s' % genes[ensg].name)

        y = 1

        # first row where this transcript does not overlap an already placed gene
        while genemap[y].find(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start):
            y += 1

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start-elt_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start-elt_start, genes[ensg].tx_end-elt_start))

        if args.labelgenes:
            lg_x  = max(genes[ensg].tx_start-elt_start, 0)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)

            # shift the label left by half its rendered width, converted from pixels to x-axis units
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(elt_end-elt_start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            _locus_add_highlights(ax0, h_start, h_end, h_colors, a)

    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    logger.info('building read alignment plot...')
    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    readstack = dd(list)

    max_y  = 1
    pack_y = 1

    masked_count = 0

    for bamname in reads:
        fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])

        # read name --> (alignment starts, alignment ends) for every retained alignment of that read
        pos_cache = {}

        for read in fetch_reads_bam.fetch(chrom, elt_start, elt_end):
            if read.mapq < int(args.min_mapq):
                continue

            if read.is_supplementary or read.is_secondary or read.is_duplicate:
                if not args.allreads:
                    continue

            masked = False
            if len(readmask) > 0:
                for mask_start, mask_end in readmask:
                    # only reads entirely contained in a mask interval are masked
                    if read.reference_start >= mask_start and read.reference_end <= mask_end:
                        logger.debug('masked read: %s' % read.query_name)
                        masked_count += 1
                        masked = True

            if masked:
                continue

            if read.query_name not in pos_cache:
                pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
            else:
                pos_cache[read.query_name][0].append(read.reference_start)
                pos_cache[read.query_name][1].append(read.reference_end)

        for readname, read in reads[bamname].items():
            if readname not in pos_cache:
                logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                continue

            if read.call_count == 0:
                continue

            # start with one read per row; rows are merged in the packing step below
            read.ypos = max_y
            max_y += 1

            read.starts, read.ends = pos_cache[readname]
            readstack[bamname].append(read)

        fetch_reads_bam.close()

    if args.readmask:
        logger.info('masked %d reads due to --readmask' % masked_count)

    # read packing

    for bamname in readstack:
        # provisional row --> reads currently assigned to that row
        y = dd(list)

        for read in readstack[bamname]:
            y[read.ypos].append(read)

        for p in y:
            for q in y:
                if p == q:
                    continue

                # iterate over a snapshot: y[q] is modified inside the loop and removing from
                # the list being iterated would silently skip the following read
                for read_q in list(y[q]):
                        move = True

                        for read_p in y[p]:
                            if read_q.overlap(read_p):
                                move = False

                        if move:
                            y[p].append(read_q)
                            y[q].remove(read_q)

        # renumber the non-empty rows consecutively (continuing across samples)
        for p in y:
            if len(y[p]) > 0:
                for read in y[p]:
                    read.ypos = pack_y
                pack_y += 1

    ax1.set_ylim(0,pack_y+1)

    for bamname in readstack:
        for read in readstack[bamname]:
            rm = args.readmarker
            ms = float(args.readmarkersize)
            lw = float(args.readlinewidth)
            la = float(args.readlinealpha)
            ma = float(args.markeralpha)
            uec = sample_color[bamname]

            if args.readopenmarkeredgecolor is not None:
                uec = args.readopenmarkeredgecolor

            # open marker: call == -1, filled black marker: call == 1, call == 0 is not drawn
            for call_pos, call in read.meth_calls.items():
                if call == -1:
                    ax1.plot(call_pos, read.ypos, marker=rm, fillstyle='full', mec=uec, mfc='white', markersize=ms, zorder=3, alpha=ma)

                if call == 1:
                    ax1.plot(call_pos, read.ypos, marker=rm, fillstyle='full', mec='black', mfc='black', markersize=ms, zorder=3, alpha=ma)

            for i in range(len(read.starts)):
                # clip each alignment to the plotted interval and convert to offsets from elt_start
                readline_start = max(read.starts[i], elt_start) - elt_start
                readline_end   = min(read.ends[i], elt_end) - elt_start

                ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos, read.ypos], lw=lw, zorder=2, color=sample_color[bamname], alpha=la))

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            _locus_add_highlights(ax1, h_start, h_end, h_colors, a)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    x1 = []
    x2 = []

    step = int(args.modspace)

    # sorted so that the first/last entries really are the outermost positions (set order is arbitrary)
    orig_positions = sorted(set(meth_table['orig_loc']))

    # always link the outermost positions, plus every step-th position in between
    for i, x in enumerate(orig_positions):
        if i in (0, len(orig_positions)-1):
            x2.append(x)
            x1.append(coord_to_cpg[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_cpg[x])

    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        # links that start inside a highlight take that highlight's colour
        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (elt_end-elt_start)/n_ticks

    # range() rejects a zero step, which int() would produce for intervals shorter than n_ticks
    tick_step = max(1, int(tick_interval))
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), tick_step))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    xt_labels = [str(int(t+elt_start)) for t in tick_list]

    # the first tick label carries the chromosome name instead of a coordinate
    if xt_labels:
        xt_labels[0] = chrom

    ax0.set_xticks(tick_list)

    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)

    # llr plot

    logger.info('building llr plot...')
    ax4 = plt.subplot(gs[3])
    ax4.set_xticks([])
    ax4.yaxis.tick_right()

    ax4.axhline(y=0, c='#bbbbbb', linestyle='--',lw=1)

    for mod in use_mods:
        upper = 1.0
        lower = 0.0

        if not methbam:
            # cutoffs are read from the first db of the first -d/--data entry
            upper, lower = get_cutoffs(list(data.values())[0][0], mod)

        ax4.axhline(y=upper, c='k', linestyle='--',lw=1)
        ax4.axhline(y=lower, c='k', linestyle='--',lw=1)

    ax4 = sns.lineplot(x='loc', y='modstat', hue='sample', data=meth_table, palette=sample_color, zorder=2)

    if args.statname:
        ax4.set_ylabel(args.statname)

    ax4.set_xlim(ax2.get_xlim())

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            _locus_add_highlights(ax4, h_cpg_start, h_cpg_end, h_colors, a)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax5 = plt.subplot(gs[4])

    # running zorder so each mask overlay is drawn above the lines plotted before it
    order_stack = 2

    smoothalpha = float(args.smoothalpha)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    for sample in sample_order:
        windowed_methfrac, meth_n = slide_window(meth_table, sample, width=int(args.slidingwindowsize), slide=int(args.slidingwindowstep))

        smoothed_methfrac = smooth(np.asarray(list(windowed_methfrac.values())), window_len=int(args.smoothwindowsize), window=args.smoothfunc)

        masked_segs = mask_methfrac(list(meth_n.values()), cutoff=args.maskcutoff)

        frac_masked = len(list(itertools.chain(*masked_segs))) / len(meth_n.values())

        logger.info('%s:%d-%d (%s), sample %s fraction masked: %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, frac_masked))

        if frac_masked > float(args.maxmaskedfrac):
            logger.warning('%s:%d-%d (%s), skip sample %s due to --maxmaskedfrac %.3f' % (chrom, elt_start, elt_end, ''.join(use_mods), sample, float(args.maxmaskedfrac)))
            continue

        ax5.plot(list(windowed_methfrac.keys()), smoothed_methfrac, marker='', lw=4, color=sample_color[sample], alpha=smoothalpha)

        order_stack += 1

        if not args.nomask:
            # overdraw masked stretches in translucent white
            for seg in masked_segs:
                if len(seg) > 2:
                    mf_seg = np.asarray(smoothed_methfrac)[seg]
                    pos_seg = np.asarray(list(windowed_methfrac.keys()))[seg]

                    ax5.plot(pos_seg, mf_seg, marker='', lw=4, color='#ffffff', alpha=0.8, zorder=order_stack)

                    order_stack += 1

    ax5.set_xlim(ax2.get_xlim())
    ax5.set_ylim((-0.05,1.05))

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            _locus_add_highlights(ax5, h_cpg_start, h_cpg_end, h_colors, a)

    # output file name encodes the inputs and the main plotting parameters
    fn_prefix = '.%s_%d_%d.%s' % (chrom, elt_start, elt_end, ''.join(use_mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    if args.phased:
        fn_prefix += '.phased'

    param_str = '.ms%d.smw%d' % (args.modspace, int(args.smoothwindowsize))

    fn_prefix += param_str

    if args.max_read_density:
        fn_prefix += '.mrd%.2f' % float(args.max_read_density)

    if args.mincalls > 0:
        fn_prefix += '.mc%d' % args.mincalls

    if args.svg:
        plt.savefig('%s.locus.meth.svg' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.svg' % fn_prefix)

    else:
        plt.savefig('%s.locus.meth.png' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.locus.meth.png' % fn_prefix)


def region(args):
    '''
    plotting function for windowed view of larger regions

    The interval given via args.interval (chrom:start-end) is divided into
    windows holding a roughly equal number of args.norm_motif occurrences.
    Per-window methylation fractions are collected through a multiprocessing
    pool, smoothed, and drawn as a figure with up to four stacked panels
    (genes, read alignments, genome-to-window correspondence, smoothed
    methylation fraction). The figure is written to the current directory as
    <prefix>.region.meth.svg or <prefix>.region.meth.png depending on args.svg.

    args is the parsed command-line namespace. Several attributes are filled in
    here when left unset (windows, smoothwindowsize, modspace) and
    skip_align_plot / height may be overridden depending on the region size.

    Exits via sys.exit() on missing/conflicting inputs, unindexed .bam files,
    too many uncovered windows or when no methylation calls are found.
    '''

    assert ':' in args.interval
    assert '-' in args.interval

    chrom, pos = args.interval.split(':')
    start, end = pos.split('-')

    start = int(start.replace(',',''))
    end = int(end.replace(',',''))

    assert start < end

    if end-start < 500000:
        logger.warning('locus smaller than 0.5 Mbp, "methylartist locus" may yield better results')

    # the reference is only needed to fetch the region sequence, release the handle right away
    ref = pysam.FastaFile(args.ref)
    motif = args.norm_motif.upper()
    region_seq = ref.fetch(chrom, start, end).upper()
    ref.close()

    motif_count = region_seq.count(motif) + region_seq.count(rc(motif))

    if args.windows is None:
        args.windows = round(motif_count / 30)
        logger.info('set window count to %d' % args.windows)

    args.windows = int(args.windows)

    if args.windows < 500:
        args.windows = 500
        logger.info('resetting windows to a minimum of 500')

    motifs_per_window = int(motif_count / int(args.windows))

    if motifs_per_window == 0:
        motifs_per_window = 1

    logger.info('motif count: %d, per window: %d' % (motif_count, motifs_per_window))

    # walk the region and close a window every time motifs_per_window motifs have been seen
    w_starts = [start]
    w_ends = []

    i = 0
    motif_count = 0

    while i < len(region_seq) - len(motif):
        site_fwd = region_seq[i:i+len(motif)]
        site_rev = rc(site_fwd)

        if motif in (site_fwd, site_rev):
            motif_count += 1

        i += 1

        if motif_count == motifs_per_window:
            w_starts.append(start+i)
            w_ends.append(start+i)
            motif_count = 0

    w_ends.append(end)

    assert len(w_starts) == len(w_ends)
    logger.info('using %d windows normalised for %s content' % (len(w_starts), args.norm_motif))

    if args.smoothwindowsize is None:
        w = len(w_starts)
        args.smoothwindowsize = round(0.02*w + 4)
        if args.smoothwindowsize % 2 != 0:
            args.smoothwindowsize += 1
            
        logger.info('set --smoothwindowsize to %d' % args.smoothwindowsize)

    args.smoothwindowsize = int(args.smoothwindowsize)

    if args.modspace is None:
        args.modspace = round(len(w_starts)/300)
        if args.modspace == 0:
            args.modspace = 1

        logger.info('set --modspace to %d' % args.modspace)

    args.modspace = int(args.modspace)

    # input data: either a table of .bam / methylation .db pairs (-d) or mod-tagged .bams (-b)

    data = dd(list)
    mods = []
    user_colours = {}

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')
        with open(args.data) as _:
            for line in _:
                c = line.strip().split()
                if len(c) < 2:
                    logger.warning("required fields for -d/--data are: .bam file and methylation .db (generated by methylartist)")
                    sys.exit()

                bam, meth_db = c[:2]
                for m_db in meth_db.split(','):
                    data[bam].append(m_db)
                    mods += sorted(get_modnames(m_db))

                if len(c) == 3:
                    user_colours[bam] = c[2]

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    line = line.strip()
                    # blank lines would otherwise be reported as an unindexed bam
                    if line:
                        bams.append(line)

        if not bams:
            sys.exit('no .bam files found in %s' % args.bams)

        for bam in bams:
            # optional per-bam colour given as file.bam:colour
            if ':' in bam:
                bam, ucol = bam.split(':')
                user_colours[bam] = ucol

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # mod names are taken from the last .bam in the list
        mods = mods_methbam(bam)

    # sorted so that sample ordering and the output filename are reproducible between runs
    mods = sorted(set(mods))

    logger.info('found mods: %s' % ','.join(mods))

    if args.mods:
        for mod in args.mods.split(','):
            assert mod in mods, 'mod %s not found' % mod
        mods = args.mods.split(',')

    logger.info('using mods %s' % ','.join(mods))

    # per-read calls, used for the alignment panel

    reads = {}
    orig_bam = {}

    phases = [None]
    if args.phased:
        phases = ['1','2']

    for phase in phases:
        for bam, meth_dbs in data.items():
            for mod in mods:
                bamname = '.'.join(os.path.basename(bam).split('.')[:-1]) + '.' + mod

                if args.phased:
                    bamname += '.' + phase

                orig_bam[bamname] = bam
                reads[bamname] = get_meth_locus(args, bam, meth_dbs, mod, phase=phase, methbam=methbam)

    # per-window call counts, one job per (window, bam, phase)

    pool = mp.Pool(processes=int(args.procs))

    meth_segs = dd(dict)
    sample_names = {}
    shallow_windows = dd(list)
    min_window_calls = int(args.min_window_calls)

    for phase in phases:
        results = []

        for seg_start, seg_end in zip(w_starts, w_ends):
            for bam_fn, meth_dbs in data.items():
                seg_strand = '.'
                seg_name = '.'.join(os.path.basename(bam_fn).split('.')[:-1])
                res = pool.apply_async(get_segmeth_calls, [args, bam_fn, mods, meth_dbs, chrom, seg_start, seg_end, seg_name, seg_strand, phase, methbam])
                results.append(res)

        logger.info('parsing segments (phase %s)...' % str(phase))

        for res in results:
            meth_result, seg = res.get()
            if meth_result is None:
                continue

            seg_id = '%s:%d-%d' % seg[:3]

            seg_chrom, seg_start, seg_end, seg_name, seg_strand = map(str, seg[:-1])

            meth_segs[seg_id]['seg_id']     = seg_id
            meth_segs[seg_id]['seg_chrom']  = seg_chrom
            meth_segs[seg_id]['seg_start']  = seg_start
            meth_segs[seg_id]['seg_end']    = seg_end
            meth_segs[seg_id]['seg_name']   = seg_name
            meth_segs[seg_id]['seg_strand'] = seg_strand

            for modname, meth_data in meth_result.items():
                # meth_data is keyed by call state: -1 unmethylated, 0 no call, 1 methylated
                no_calls = 0
                meth_calls = 0
                unmeth_calls = 0

                if -1 in meth_data:
                    unmeth_calls = meth_data[-1]

                if 0 in meth_data:
                    no_calls = meth_data[0]

                if 1 in meth_data:
                    meth_calls = meth_data[1]

                if meth_calls + unmeth_calls < min_window_calls:
                    if seg_id not in shallow_windows[modname]:
                        shallow_windows[modname].append(seg_id) # TODO phase aware

                sample_name = seg_name + '.' + modname
                
                if args.phased:
                    sample_name += '.' + phase

                sample_names[sample_name] = True

                meth_segs[seg_id][sample_name + '_meth_calls'] = meth_calls
                meth_segs[seg_id][sample_name + '_unmeth_calls'] = unmeth_calls

                if meth_calls + unmeth_calls == 0:
                    meth_segs[seg_id][sample_name + '_frac'] = 0 # changed from NaN
                else:
                    meth_segs[seg_id][sample_name + '_frac'] = meth_calls/float(meth_calls+unmeth_calls)

    # all results have been collected, shut the workers down
    pool.close()
    pool.join()

    for mod in mods: # TODO phase aware
        shallow_frac = len(shallow_windows[mod])/len(w_starts)*100.0
        if shallow_frac > 0.0:
            logger.info('%.2f percent of windows for mod %s had less than %d calls' % (shallow_frac, mod, min_window_calls))
            if shallow_frac > float(args.maxuncovered):
                sys.exit('greater than %.2f windows are uncovered, aborting.' % float(args.maxuncovered))

    deleted_segs = 0
    for mod in mods: # TODO phase aware
        for seg_id in shallow_windows[mod]:
            if seg_id in meth_segs:
                del meth_segs[seg_id]
                deleted_segs += 1

    if deleted_segs > 0:
        logger.info('removed %d segs with less than %d calls in at least one mod' % (deleted_segs, min_window_calls))

    meth_segs = pd.DataFrame.from_dict(meth_segs).T

    if 'seg_start' not in meth_segs:
        logger.warning('no methylation calls.')
        sys.exit()

    meth_segs['seg_start'] = pd.to_numeric(meth_segs['seg_start'])

    # 'pos' is the window index after sorting by genomic start (the x axis of the lower panels)
    meth_segs.sort_values('seg_start', inplace=True)
    meth_segs['pos'] = np.arange(len(meth_segs.index))

    for sample in sample_names:
        meth_segs[sample] = smooth(np.asarray(meth_segs[sample + '_frac']), window_len=int(args.smoothwindowsize), window=args.smoothfunc)
        meth_segs[sample] = meth_segs[sample].rolling(window=10, min_periods=1, center=True).mean()

    coord_to_pos = {}
    for orig_loc, new_loc in zip(meth_segs['seg_start'], meth_segs['pos']):
        coord_to_pos[orig_loc] = new_loc

    # highlights: kept both in genome coordinates (h_start/h_end) and window index space (h_cpg_*)

    h_start = []
    h_end = []
    h_cpg_start = []
    h_cpg_end = []
    h_colors = []

    if args.highlight:
        h_colors = sns.color_palette(args.highlightpalette, n_colors=len(args.highlight.split(',')))

    if args.highlight:
        for h in args.highlight.split(','):
            if ':' in h:
                h = h.split(':')[-1]
                
            h_s, h_e = map(int, h.split('-'))
            h_start.append(h_s)
            h_end.append(h_e)

            # snap each highlight boundary to the nearest window start
            h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
            h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

    if args.highlight_bed:
        need_colour = 0

        with open(args.highlight_bed) as h_bed:
            for line in h_bed:
                c = line.strip().split()
                assert len(c) >= 3, 'malformed line in --highlight_bed: %s' % line.strip()
                assert c[0] == chrom, 'all entries in --highlight_bed must be on the same chromosome as -i/--interval'

                h_s, h_e = map(int, c[1:3])

                colour = None

                if len(c) > 3:
                    colour = c[3]
                else:
                    need_colour += 1
                
                h_start.append(h_s)
                h_end.append(h_e)

                h_cpg_start.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_start[-1]))])
                h_cpg_end.append(coord_to_pos[min(meth_segs['seg_start'], key=lambda x:abs(x-h_end[-1]))])

                h_colors.append(colour)

        # entries without a colour column get one from the highlight palette
        more_colours = sns.color_palette(args.highlightpalette, n_colors=need_colour)

        for i, c in enumerate(h_colors):
            if c is None:
                h_colors[i] = more_colours.pop()

    # mask

    readmask = []
    if args.readmask:
        for ivl in args.readmask.split(','):
            if ':' in ivl:
                ivl = ivl.split(':')[1]
            assert '-' in ivl, 'malformed --readmask interval(s): %s' % args.readmask

            readmask.append(list(map(int, ivl.split('-'))))

    # set up plot

    if end-start > 500000:
        logger.info('region size %s greater than 0.5 Mbp, setting --skip_align_plot True' % (end-start))
        args.skip_align_plot = True
    
    if args.force_align_plot == args.skip_align_plot == True:
        logger.info('--skip_align_plot overridden via --force_align_plot')
        args.skip_align_plot = False

    fig = plt.figure()
    height_ratios = [1,5,1,3]

    if args.skip_align_plot:
        height_ratios = [1,0,1,4]

    gs = gridspec.GridSpec(4,1,height_ratios=height_ratios, hspace=0)

    img_w = 16
    img_h = 8

    if args.skip_align_plot:
        # use a more compact figure unless the user explicitly asked for a height
        height_given = any(a == '--height' or a.startswith('--height=') for a in sys.argv[1:])
        if not height_given:
            args.height = 4.5

    if args.width:
        img_w = float(args.width)

    if args.height:
        img_h = float(args.height)

    fig.set_size_inches(img_w, img_h)

    if args.panelratios:
        p_ratios = list(map(int, args.panelratios.split(',')))
        gs = gridspec.GridSpec(4,1,height_ratios=p_ratios, hspace=0)

    sample_order = sorted(sample_names.keys())

    logger.info('sample order: %s' % ','.join(sample_order))

    sample_color = {}
    if sample_order:
        sample_palette = sns.color_palette(args.samplepalette, n_colors=len(sample_order))
        for i, sample in enumerate(sample_order):
            sample_color[sample] = sample_palette[i]

    # colours given on the command line / in the data table take precedence
    for sample in sample_color:
        if orig_bam[sample] in user_colours:
            sample_color[sample] = user_colours[orig_bam[sample]]

    if args.color_by_hp:
        if args.phased:
            hp_palette = sns.color_palette(args.samplepalette, n_colors=2)
            hp_color = {'1': hp_palette[0], '2': hp_palette[1]}
            
            for sample in sample_color:
                # phased sample names end in .1 or .2
                hp = sample.split('.')[-1]
                assert hp in ('1','2')
                sample_color[sample] = hp_color[hp]

        else:
            logger.warning('--color_by_hp has no effect without --phased')

    # plot genes

    ax0 = plt.subplot(gs[0])

    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.spines['right'].set_visible(False)
    ax0.xaxis.set_ticks_position('top')

    gtf = None
    genes = [] 

    if args.gtf is not None:
        logger.info('building genes plot...')
        gtf = pysam.Tabixfile(args.gtf)
        genes = build_genes(gtf, chrom, start, end)

    exon_patches = []
    tx_lines = []

    genes_of_interest = []

    if args.genes is not None:
        genes_of_interest = args.genes.strip().split(',')

    # one interval index per gene track row, used to stack overlapping genes
    genemap = dd(Intersecter)

    if genes_of_interest:
        new_genes = {}
        for ensg in genes:
            if genes[ensg].name in genes_of_interest:
                new_genes[ensg] = genes[ensg]
        
        genes = new_genes

    if args.gtf:
        logger.info('%d genes in region' % len(genes))

    gene_colours = sns.color_palette(args.genepalette, n_colors=len(genes))

    for i, ensg in enumerate(genes):
        if not genes[ensg].has_tx():
            continue

        y = 1

        # find the first row where this gene does not collide with one already placed
        while genemap[y].find(genes[ensg].tx_start, genes[ensg].tx_end):
            y += 1
            if args.gene_track_height:
                if y > int(args.gene_track_height):
                    y = int(args.gene_track_height)
                    break

        tx_lines.append(matplotlib.lines.Line2D([genes[ensg].tx_start, genes[ensg].tx_end], [0.4+y, 0.4+y], color=gene_colours[i], zorder=2))

        genes[ensg].merge_exons()
        for exon_start, exon_end in genes[ensg].exons:
            exon_len = exon_end - exon_start
            exon_patches.append(matplotlib.patches.Rectangle([exon_start, y], exon_len, 0.8, edgecolor=gene_colours[i], facecolor=gene_colours[i], zorder=3))

        genemap[y].add_interval(Interval(genes[ensg].tx_start, genes[ensg].tx_end))

        if args.labelgenes:
            lg_x  = max(genes[ensg].tx_start, start)
            gtxt  = ax0.text(lg_x, y+0.8, genes[ensg].name, bbox=dict(boxstyle="round,pad=0.3", fc="lavender", alpha=0.5, lw=0), zorder=4)
            # shift the label left by half its rendered width, converted to genome coordinates
            bb_w  = gtxt.get_tightbbox(renderer=fig.canvas.get_renderer()).width
            fig_w = fig.get_size_inches()[0]*fig.dpi
            txt_w = bb_w/fig_w*(end-start)
            gtxt.set_x(lg_x-txt_w/2)

    gene_height = max(3, len(genemap))

    if args.labelgenes:
        gene_height = max(6, len(genemap))

    ax0.set_ylim(0, gene_height)
    ax0.set_yticks([])

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            for h, h_s in enumerate(h_start):
                h_e = h_end[h]
                h_color = h_colors[h]
                ymin = min(ax0.get_ybound())
                yheight = max(ax0.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax0.add_patch(h)
        
    for p in exon_patches:
        ax0.add_patch(p)

    for tx in tx_lines:
        ax0.add_line(tx)

    # per-read plot

    ax1 = plt.subplot(gs[1])
    ax1.set_xticks([])
    ax1.set_yticks([])

    if args.skip_align_plot:
        logger.info('skipped alignment plot due to --skip_align_plot')

    else:
        logger.info('building read alignment plot...')

        readstack = dd(list)

        max_y  = 1
        pack_y = 1

        masked_count = 0

        for bamname in reads:
            fetch_reads_bam = pysam.AlignmentFile(orig_bam[bamname])

            # read name -> ([alignment starts], [alignment ends]) for alignments passing the filters
            pos_cache = {}

            for read in fetch_reads_bam.fetch(chrom, start, end):
                if read.mapq < int(args.min_mapq):
                    continue

                if read.is_supplementary or read.is_secondary or read.is_duplicate:
                    if not args.allreads:
                        continue

                masked = False
                if len(readmask) > 0:
                    for mask_start, mask_end in readmask:
                        # only alignments fully contained in a mask interval are dropped
                        if read.reference_start >= mask_start and read.reference_end <= mask_end:
                            logger.debug('masked read: %s' % read.query_name)
                            masked_count += 1
                            masked = True
                
                if masked:
                    continue

                if read.query_name not in pos_cache:
                    pos_cache[read.query_name] = ([read.reference_start], [read.reference_end])
                else:
                    pos_cache[read.query_name][0].append(read.reference_start)
                    pos_cache[read.query_name][1].append(read.reference_end)

            for readname, read in reads[bamname].items():
                if readname not in pos_cache:
                    logger.debug('read %s not found in %s (skipped)' % (readname, bamname))
                    continue

                if read.call_count == 0:
                    continue

                # every read initially gets its own row, rows are merged below
                read.ypos = max_y
                max_y += 1

                read.starts, read.ends = pos_cache[readname]
                readstack[bamname].append(read)

            fetch_reads_bam.close()

        if args.readmask:
            logger.info('masked %d reads due to --readmask' % masked_count)

        # read packing

        for bamname in readstack:
            y = dd(list)

            for read in readstack[bamname]:
                y[read.ypos].append(read)

            for p in y:
                for q in y:
                    if p == q:
                        continue

                    # iterate over a snapshot: reads are removed from y[q] as they are moved,
                    # and removing from a list while iterating over it skips elements
                    for read_q in list(y[q]):
                        if not any(read_q.overlap(read_p) for read_p in y[p]):
                            y[p].append(read_q)
                            y[q].remove(read_q)

            # renumber the non-empty rows consecutively, continuing across samples
            for p in y:
                if len(y[p]) > 0:
                    for read in sorted(y[p], key=lambda r: min(r.starts)):
                        read.ypos = pack_y
                    pack_y += 1

        ax1.set_ylim(0,pack_y+1)

        for bamname in readstack:
            for read in readstack[bamname]:
                for call_pos, call in read.meth_calls.items():
                    call_pos += start

                    # open marker: unmethylated call, filled black marker: methylated call
                    if call == -1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec=sample_color[bamname], mfc='white', markersize=2, zorder=3)

                    if call == 1:
                        ax1.plot(call_pos, read.ypos, marker='o', fillstyle='full', mec='black', mfc='black', markersize=2, zorder=3)

                for i in range(len(read.starts)):
                    # clip alignment lines to the plotted interval
                    readline_start = max(read.starts[i], start)
                    readline_end   = min(read.ends[i], end)

                    ax1.add_line(matplotlib.lines.Line2D([readline_start, readline_end], [read.ypos, read.ypos], zorder=2, color=sample_color[bamname], alpha=0.4))

        highlight_patches = []

        if args.highlight_panels:
            a = float(args.highlight_panels)
            if args.highlight or args.highlight_bed:
                for h, h_s in enumerate(h_start):
                    h_e = h_end[h]
                    h_color = h_colors[h]
                    ymin = min(ax1.get_ybound())
                    yheight = max(ax1.get_ybound())-ymin
                    highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

        for h in highlight_patches:
            ax1.add_patch(h)

    # plot correspondence between genome space and cpg space

    logger.info('building mod-space plot...')
    ax2 = plt.subplot(gs[2])
    ax3 = ax2.twiny()

    ax2.set_ylim(0,10)
    ax2.set_yticklabels([])

    # x1: window index (ax2, bottom), x2: genomic coordinate (ax3, top)
    x1 = []
    x2 = []

    step = int(args.modspace)

    for i, x in enumerate(meth_segs['seg_start']):
        # always link the first and last window, otherwise every step-th window
        if i in (0, len(meth_segs['seg_start'])-1):
            x2.append(x)
            x1.append(coord_to_pos[x])

        elif i % step == 0:
            x2.append(x)
            x1.append(coord_to_pos[x])

    
    ax2.vlines(x1, 0, 1, color='#777777', zorder=1)
    ax3.vlines(x2, 9, 10, color='#777777', zorder=1)

    if args.highlight or args.highlight_bed:
        for i in range(len(h_start)):
            orig_highlight_box = matplotlib.patches.Rectangle((h_start[i],9), h_end[i]-h_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=2)
            cpg_highlight_box = matplotlib.patches.Rectangle((h_cpg_start[i],0), h_cpg_end[i]-h_cpg_start[i], 1.0, lw=1, edgecolor='#777777', facecolor=h_colors[i], zorder=3)

            ax3.add_patch(orig_highlight_box)
            ax2.add_patch(cpg_highlight_box)

    for x1_x, x2_x in zip(x1, x2):
        link_end1 = (x1_x, 1)
        link_end2 = (x2_x, 9)

        l_col = '#777777'

        # links starting inside a highlight take the highlight colour (last match wins)
        for i in range(len(h_start)):
            if x2_x >= h_start[i] and x2_x <= h_end[i]:
                l_col = h_colors[i]

        con = ConnectionPatch(xyA=link_end1, xyB=link_end2, coordsA="data", coordsB="data", axesA=ax2, axesB=ax3, color=l_col)
        ax3.add_artist(con)

    ax0.set_xlim(ax3.get_xlim()) # sync axes between orig coords and gtf plot
    ax1.set_xlim(ax3.get_xlim())
    ax2.set_xticks([])
    ax3.set_xticks([])

    n_ticks = int(args.nticks) + 1
    tick_interval = (end-start)/n_ticks

    # range() rejects a step of 0, which happens when the interval is shorter than n_ticks
    tick_step = max(1, int(tick_interval))
    tick_list = list(range(int(ax0.get_xlim()[0]), int(ax0.get_xlim()[1]), tick_step))

    revised_tick_list = []
    for t in tick_list:
        if t < 0:
            revised_tick_list.append(0)
        else:
            revised_tick_list.append(t)

    tick_list = sorted(list(set(revised_tick_list)))

    # the first tick label is replaced by the chromosome name
    xt_labels = [str(int(t)) for t in tick_list]
    xt_labels[0] = chrom

    ax0.set_xticks(tick_list)
    if n_ticks > 11:
        ax0.set_xticklabels(xt_labels, rotation=45)
    else:
        ax0.set_xticklabels(xt_labels)

    # meth frac plot

    logger.info('building meth frac plot...')
    ax4 = plt.subplot(gs[3])

    smoothalpha = float(args.smoothalpha)

    if smoothalpha > 1.0 or smoothalpha < 0.0:
        logger.warning('--smoothalpha must be between 0 and 1, set to 1.0')
        smoothalpha = 1.0

    for sample in sample_order:
        ax4.plot(meth_segs['pos'], meth_segs[sample], marker='', lw=4, color=sample_color[sample], zorder=2, label=sample, alpha=smoothalpha)

    ax4.legend()
    ax4.set_xlim(ax2.get_xlim())
    ax4.set_ylim((-0.05,1.05))

    highlight_patches = []

    if args.highlight_panels:
        a = float(args.highlight_panels)
        if args.highlight or args.highlight_bed:
            # this panel is in window index space, so use the snapped highlight coordinates
            for h, h_s in enumerate(h_cpg_start):
                h_e = h_cpg_end[h]
                h_color = h_colors[h]
                ymin = min(ax4.get_ybound())
                yheight = max(ax4.get_ybound())-ymin
                highlight_patches.append(matplotlib.patches.Rectangle([h_s, ymin], h_e-h_s, yheight, edgecolor=None, facecolor=h_color, alpha=a, zorder=1))

    for h in highlight_patches:
        ax4.add_patch(h)

    # output

    fn_prefix = '.%s_%d_%d.%s' % (chrom, start, end, ''.join(mods))

    if methbam:
        fn_prefix = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1]) + fn_prefix
        if len(args.bams.split(',')) > 1:
            fn_prefix += '.cohort'
    else:
        fn_prefix = '.'.join(os.path.basename(args.data).split('.')[:-1]) + fn_prefix

    if args.genes is not None:
        fn_prefix = '_'.join(args.genes.split(',')) + '.' + fn_prefix

    fn_prefix += '.s%d.w%d.m%d' % (args.smoothwindowsize, args.windows, args.modspace)

    if args.phased:
        fn_prefix += '.phased'

    if args.svg:
        plt.savefig('%s.region.meth.svg' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.region.meth.svg' % fn_prefix)

    else:
        plt.savefig('%s.region.meth.png' % fn_prefix, bbox_inches='tight')
        logger.info('plot saved to %s.region.meth.png' % fn_prefix)


def composite(args):
    '''
    plot composite methylation profiles relative to a consensus element

    For every stranded segment listed in args.segdata a methylation profile in
    consensus (args.teref) coordinates is computed through a multiprocessing
    pool. The figure has three panels: the per-position mean profile per
    sample, a schematic of the consensus with motif positions (and optional
    args.blocks annotations), and a random subset (args.outelts) of the
    individual element profiles.

    args is the parsed command-line namespace. Requires scikit-bio. Output is
    written to <data>.<segdata>[.phased].composite.svg/.png in the current
    directory.

    Exits via sys.exit() on missing/conflicting inputs, unindexed .bam files,
    an unresolvable modification name or segment lines without a strand.
    '''

    if not skbio_installed:
        sys.exit('scikit-bio is not installed but is required for this function. Please install e.g. via "pip install scikit-bio" or "conda install -c https://conda.anaconda.org/biocore scikit-bio"')

    te_ref_seq = single_seq_fa(args.teref).upper()

    assert os.path.exists(args.ref + '.fai'), 'ref fasta must be indexed'

    mod_names = []

    data = dd(list)

    if args.data is None and args.bams is None:
        sys.exit('please specify either -d/--data or -b/--bams')

    if args.color_by_phase and not args.phased:
        sys.exit('must specify --phased to use --color_by_phase')

    methbam = False

    if args.data is not None:
        if args.bams is not None:
            sys.exit('please specify either -d/--data or -b/--bams but not both')

        with open(args.data) as _:
            for line in _:
                bam, meth = line.strip().split()
                if ',' in meth:
                    sys.exit('multiple .db files per bam not supported for composite')

                data[bam] = meth

                for m in get_modnames(meth):
                    mod_names.append(m)

    if args.bams is not None:
        methbam = True
        bams = []

        if args.bams.endswith('.bam') or (':' in args.bams and args.bams.split(':')[0].endswith('.bam')):
            bams = args.bams.split(',')

        else:
            logger.info('assuming %s contains a list of .bams' % args.bams)
            with open(args.bams) as bam_list:
                for line in bam_list:
                    line = line.strip()
                    # blank lines would otherwise be reported as an unindexed bam
                    if line:
                        bams.append(line)

        if not bams:
            sys.exit('no .bam files found in %s' % args.bams)

        for bam in bams:
            # a colour suffix (file.bam:colour) is accepted but not used here
            if ':' in bam:
                bam, _ = bam.split(':')

            if not os.path.exists(bam+'.bai'):
                sys.exit('bam not indexed: %s' % bam)

            data[bam] = None

        # mod names are taken from the last .bam in the list
        mod_names = mods_methbam(bam)

    # sorted so that log / error messages list mods in a stable order
    mod_names = sorted(set(mod_names))

    logger.info('found mod names: %s' % ','.join(mod_names))

    use_mod = None

    if len(mod_names) == 1:
        use_mod = mod_names[0]
        logger.info('using the one available mod: %s' % use_mod)

    if use_mod is None:
        if args.mod is None:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))
            
        use_mod = args.mod

        if use_mod not in mod_names:
            sys.exit('please specify a modification via --mod\navailable mods are: %s' % ','.join(mod_names))

    # output filename

    data_basename = None

    if methbam:
        data_basename = '.'.join(os.path.basename(args.bams.split(',')[0]).split('.')[:-1])
        if len(args.bams.split(',')) > 1:
            data_basename += '.cohort'

    else:
        data_basename = '.'.join(os.path.basename(args.data).split('.')[:-1])    

    seg_basename = '.'.join(os.path.basename(args.segdata).split('.')[:-1])

    outfn = data_basename + '.' + seg_basename

    if args.phased:
        outfn += '.phased'

    outfn += '.composite'

    if args.svg:
        outfn += '.svg'
    else:
        outfn += '.png'

    logger.info('composite output filename: %s' % outfn)

    # one job per segment (and per phase when --phased)

    pool = mp.Pool(processes=int(args.procs))

    results = []

    with open(args.segdata) as bed:
        for line in bed:
            c = line.strip().split()

            if not c:
                continue

            # strand is expected in column 4 or 5; check the column exists before reading it
            if len(c) > 3 and c[3] in ('-', '+'):
                seg_strand = c[3]

            elif len(c) > 4 and c[4] in ('-', '+'):
                seg_strand = c[4]

            else:
                sys.exit('strand (+/-) not found in cols 4 or 5 of %s' % args.segdata)

            seg_chrom  = c[0]
            seg_start  = int(c[1])
            seg_end    = int(c[2])

            if args.phased:
                for phase in ('1','2'):
                    res = pool.apply_async(get_meth_profile_composite, [args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, phase])
                    results.append(res)

            else:
                res = pool.apply_async(get_meth_profile_composite, [args, data, methbam, seg_chrom, seg_start, seg_end, seg_strand, use_mod, None])
                results.append(res)

    # collect mod data
    out_res = dd(list) # cache for --outelts

    for res in results:
        per_bam_res = res.get()

        for bam in per_bam_res:
            if per_bam_res[bam] is None:
                continue

            coord_meth_pos, meth_profile = per_bam_res[bam]

            if len(coord_meth_pos) == 0:
                continue

            out_res[bam].append((coord_meth_pos, meth_profile))

    # all results have been collected, shut the workers down
    pool.close()
    pool.join()

    # set bounds
    mod_start = 0
    mod_end = len(te_ref_seq)

    if args.start:
        mod_start = int(args.start)

    if args.end:
        mod_end = int(args.end)

    assert mod_start < mod_end

    # set up plot

    sample_color = {}
    if args.color_by_phase:
        phase_palette = sns.color_palette(args.palette, n_colors=2)
        for i, phase in enumerate(('phase1', 'phase2')):
            sample_color[phase] = phase_palette[i]

    elif out_res:
        sample_palette = sns.color_palette(args.palette, n_colors=len(out_res))
        for i, bam in enumerate(out_res):
            sample_color[bam] = sample_palette[i]

    fig = plt.figure()

    gs = gridspec.GridSpec(3,1,height_ratios=[3,1,8])

    g = 0

    # mean

    ax0 = plt.subplot(gs[g])
    g += 1
    ax0.set_ylim((-0.05,1.05))
    ax0.set_xlim((mod_start, mod_end))
    ax0.set_xticks([])

    # long-format table, one row per (sample, coordinate, value) observation
    meanplot_table = dd(dict)

    for bam in out_res:

        meth_by_coord = dd(list)

        per_elt_calls = []

        for coord_meth_pos, meth_profile in out_res[bam]:
            per_elt_calls.append(len(coord_meth_pos))
            for c, m in zip(coord_meth_pos, meth_profile):
                meth_by_coord[c].append(m)

        median_call_count = int(np.median(per_elt_calls))

        # per-coordinate observation counts, highest first
        v = sorted([len(c) for c in meth_by_coord.values()], reverse=True)

        # keep only coordinates at least as well covered as the one ranked at median_call_count
        cutoff = v[-1]

        if len(v) > median_call_count:
            cutoff = v[median_call_count]

        logger.info('%s: median per element call count: %d' % (bam, median_call_count))
        logger.info('%s: per site call count cutoff: %d' % (bam, cutoff))

        for c in sorted(meth_by_coord.keys()):
            if len(meth_by_coord[c]) >= cutoff:
                for m in meth_by_coord[c]:
                    u = str(uuid4())
                    if args.color_by_phase:
                        meanplot_table[u]['sample'] = bam.split('.')[-1]
                    else:    
                        meanplot_table[u]['sample'] = bam

                    meanplot_table[u]['coord'] = c
                    meanplot_table[u]['meth'] = m
    
    meanplot_table = pd.DataFrame.from_dict(meanplot_table).T

    # drawn on the current axes, i.e. the subplot created above
    # NOTE(review): ci= is deprecated in newer seaborn releases in favour of errorbar= - confirm against the pinned version
    ax0 = sns.lineplot(x='coord', y='meth', data=meanplot_table, ci='sd', lw=2, hue='sample', palette=sample_color)
    ax0.set_ylabel(args.meanplot_ylabel)

    # mod
    ax1 = plt.subplot(gs[g])
    g += 1
    ax1.set_xlim((mod_start, mod_end))

    # background box spanning the displayed part of the consensus
    box = matplotlib.patches.Rectangle([mod_start, 0], mod_end-mod_start, 1.0, edgecolor='#555555', facecolor='#cfcfcf', zorder=1)
    ax1.add_patch(box)

    if args.blocks:
        # whitespace-delimited: start, end, name, colour
        with open(args.blocks) as blocks:
            for line in blocks:
                b_start, b_end, b_name, b_col = line.strip().split()
                b_start = int(b_start)
                b_end = int(b_end)

                box = matplotlib.patches.Rectangle([b_start, 0], b_end-b_start, 1.0, edgecolor='#555555', facecolor=b_col, zorder=2)
                ax1.add_patch(box)

    # the consensus was upper-cased above, so compare against an upper-cased motif
    motif = args.motif.upper()

    # +1 so that a motif ending on the last base of the consensus is included
    mod_locs = [
        i for i in range(len(te_ref_seq)-len(motif)+1)
        if mod_start <= i <= mod_end and te_ref_seq[i:i+len(motif)] == motif
    ]

    ax1.vlines(mod_locs, 0, 1, lw=1, colors=('#FF4500'), zorder=3, alpha=0.5)

    ax1.spines['bottom'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim((mod_start, mod_end))
    ax1.xaxis.set_ticks_position('top')

    # wiggles
    ax2 = plt.subplot(gs[g])
    g += 1

    for bam in out_res:
        colour_key = bam.split('.')[-1] if args.color_by_phase else bam

        # random.sample raises ValueError when asked for more items than are available
        n_elts = min(int(args.outelts), len(out_res[bam]))

        for coord_meth_pos, meth_profile in random.sample(out_res[bam], n_elts):
            ax2.plot(coord_meth_pos, meth_profile, lw=float(args.linewidth), alpha=float(args.alpha), color=sample_color[colour_key])

    ax2.set_ylim((-0.05,1.05))
    ax2.set_xlim((mod_start, mod_end))
    ax2.set_xlabel('position')

    fig.set_size_inches(16, 6)

    plt.savefig(outfn, bbox_inches='tight')
    logger.info('plotted to %s' % outfn)


def _wgmeth_outfn(args, phase=None):
    '''
    build the wgmeth output filename: <bam basename minus extension>[.chrom].<mod>[.phase_N].<DSS.txt|methyl.bed>
    '''
    outfn = '.'.join(os.path.basename(args.bam).split('.')[:-1])

    if args.chrom:
        outfn += '.%s' % args.chrom

    outfn += '.%s' % str(args.mod)

    if phase is not None:
        outfn += '.phase_%d' % phase

    if args.dss:
        outfn += '.DSS.txt'
    else:
        outfn += '.methyl.bed'

    return outfn


def _write_wgmeth_table(args, table, outfn):
    '''
    sort one per-site table (indexed by position, with chr/N/X columns) and write it to outfn

    DSS output has a header and columns chr,pos,N,X; bedMethyl output has no header.
    '''
    if table.empty:
        # an empty frame has no 'chr' column to sort on, so there is nothing sensible to write
        logger.warning('no calls available for %s, skipping' % outfn)
        return

    table['pos'] = table.index
    table = table.sort_values(['chr', 'pos'])

    logger.info('writing %s' % outfn)

    if args.dss:
        table.to_csv(outfn, columns=['chr','pos','N','X'], index=False, sep='\t') # 1-based
        return

    # the name column is derived from the methylation db path; in .bam-only mode
    # (no --methdb) fall back to the .bam path instead of failing on None
    name_src = args.methdb if args.methdb is not None else args.bam

    table['score']  = table['N'].where(table['N'] <= 1000, 1000) # cap score at 1000
    table['start']  = table['pos']-1 # 0-based
    table['end']    = table['pos']
    table['pct']    = table['X']/table['N']*100 # percentage, identical for phased and unphased output
    table['name']   = '.'.join(name_src.split('.')[:-1])
    table['strand'] = '.'
    table['colour'] = '0,0,0'

    table.to_csv(outfn, columns=['chr','start','end','name', 'score', 'strand', 'start', 'end', 'colour', 'N', 'pct'], header=False, index=False, sep='\t')


def wgmeth(args):
    '''
    generates whole genome output in DSS or bedmethyl format

    Reads from args: bam, methdb, binsize, fai, mod, procs, chrom, dss, phased.

    Mod names come from the .bam (when args.methdb is None) or from the methylation db.
    If -m/--mod is not given and exactly one mod is available, that mod is used
    (args.mod is set accordingly); with several mods available the user must pick one.

    Each chromosome listed in the .fai is split into bins of args.binsize and
    get_meth_calls_wg is run on every bin through a multiprocessing pool. Per-bin results
    are concatenated and written to one file (or one file per phase with --phased) in the
    current working directory.

    Exits via sys.exit() if no usable mod can be determined or if no calls were returned.
    '''
    bin_size = int(args.binsize)
    methbam = args.methdb is None

    if methbam:
        mods = sorted(mods_methbam(args.bam))
    else:
        mods = sorted(get_modnames(args.methdb))

    if methbam and len(mods) == 0:
        sys.exit('bam %s does not appear to contain Mm/Ml tags' % args.bam)

    # handle the "no mod given" case first, otherwise it is swallowed by the membership check below
    if args.mod is None:
        if len(mods) > 1:
            logger.warning('more than one mod exists, need to pick one with -m/--mod: %s' % ','.join(mods))
            sys.exit()

        if len(mods) == 1:
            args.mod = mods[0]
            logger.info('no mod specified with -m/--mod, using the only available mod: %s' % args.mod)

    if args.mod not in mods:
        logger.warning('mod %s not in known mods for db: %s' % (args.mod, ','.join(mods)))
        sys.exit()

    phases = (0,1) if args.phased else (0,)
    seg_meth_table_store = dd(list)

    # all results are collected inside the with-block, so terminating the pool on exit is safe
    with mp.Pool(processes=int(args.procs)) as pool:
        results = []

        with open(args.fai) as fai:
            for line in fai:
                if not line.strip():
                    continue

                chrom, chrlen = line.strip().split()[:2]
                chrlen = int(chrlen)

                if args.chrom and args.chrom != chrom:
                    continue

                for seg_start in range(0, chrlen, bin_size):
                    seg_end = min(seg_start + bin_size, chrlen)

                    res = pool.apply_async(get_meth_calls_wg, [args, args.bam, args.methdb, chrom, seg_start, seg_end, args.phased, args.mod])
                    results.append(res)

        for res in results:
            seg_meth_table = res.get()

            if seg_meth_table is None:
                continue

            for phase in phases:
                seg_meth_table_store[phase].append(pd.DataFrame.from_dict(seg_meth_table[phase]).T)

    if not seg_meth_table_store[0]:
        # pd.concat raises on an empty list; report the actual problem instead
        sys.exit('no methylation calls found for mod %s, nothing to write' % args.mod)

    for phase in phases:
        table = pd.concat(seg_meth_table_store[phase])
        outfn = _wgmeth_outfn(args, phase if args.phased else None)
        _write_wgmeth_table(args, table, outfn)

def main(args):
    '''
    log the invoking command line, then call args.func with the parsed arguments
    '''
    command_line = ' '.join(sys.argv)
    logger.info('starting methylartist with command: %s' % command_line)

    run_tool = args.func
    run_tool(args)


if __name__ == '__main__':
    # Command-line interface: one required subcommand ("tool") per analysis or
    # database-building function. Each subparser stores its handler in args.func,
    # which main() then calls with the parsed arguments.
    # Numeric options are declared without type=, so values given on the command
    # line arrive as strings.
    parser = argparse.ArgumentParser(description='methylartist: tools for exploring nanopore modified base data')
    subparsers = parser.add_subparsers(title="tool", dest="tool")
    subparsers.required = True

    __version__ = "1.2.0"
    parser.add_argument('-v', '--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    parser_nanopolish    = subparsers.add_parser('db-nanopolish')
    parser_megalodon     = subparsers.add_parser('db-megalodon')
    parser_custom        = subparsers.add_parser('db-custom')
    parser_guppy         = subparsers.add_parser('db-guppy')
    parser_segmeth       = subparsers.add_parser('segmeth')
    parser_segplot       = subparsers.add_parser('segplot')
    parser_locus         = subparsers.add_parser('locus')
    parser_region        = subparsers.add_parser('region')
    parser_composite     = subparsers.add_parser('composite')
    parser_wgmeth        = subparsers.add_parser('wgmeth')
    parser_adjustcutoffs = subparsers.add_parser('adjustcutoffs')
    parser_scoredist     = subparsers.add_parser('scoredist')

    # bind each subcommand to the function implementing it
    parser_segmeth.set_defaults(func=segmeth)
    parser_segplot.set_defaults(func=segplot)
    parser_locus.set_defaults(func=locus)
    parser_region.set_defaults(func=region)
    parser_composite.set_defaults(func=composite)
    parser_wgmeth.set_defaults(func=wgmeth)
    parser_nanopolish.set_defaults(func=db_nanopolish)
    parser_megalodon.set_defaults(func=db_megalodon)
    parser_custom.set_defaults(func=db_custom)
    parser_guppy.set_defaults(func=db_guppy)
    parser_adjustcutoffs.set_defaults(func=adjustcutoffs)
    parser_scoredist.set_defaults(func=scoredist)

    # options for methylation segments
    parser_segmeth.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_segmeth.add_argument('-b', '--bams', default=None, help='one or more .bams with Mm and Ml tags for modification calls (see samtags spec)')
    parser_segmeth.add_argument('-i', '--intervals', required=True, help='.bed file')
    parser_segmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_segmeth.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segmeth.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_segmeth.add_argument('--excl_ambig', action='store_true', default=False, help='do not consider reads that align entirely within segment')
    parser_segmeth.add_argument('--spanning_only', action='store_true', default=False, help='only consider reads that span segment')
    parser_segmeth.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')

    # options for methylation strip / violin plots
    parser_segplot.add_argument('-s', '--segmeth', required=True, help='output from segmeth.py')
    parser_segplot.add_argument('-m', '--samples', default=None, help='samples, comma delimited')
    parser_segplot.add_argument('-d', '--mods', default=None, help='mods, comma delimited')
    parser_segplot.add_argument('-c', '--categories', default=None, help='categories, comma delimited, need to match seg_name column from input')
    parser_segplot.add_argument('-v', '--violin', default=False, action='store_true')
    parser_segplot.add_argument('-g', '--ridge', default=False, action='store_true')
    parser_segplot.add_argument('-n', '--mincalls', default=10, help='minimum number of calls to include site (methylated + unmethylated) (default=10)')
    parser_segplot.add_argument('-r', '--minreads', default=1, help='minimum reads in interval (default = 1)')
    parser_segplot.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_segplot.add_argument('-a', '--group_by_annotation', default=False, action='store_true', help='group plots by annotation rather than by sample')
    parser_segplot.add_argument('--width', default=None, help='figure width (default = automatic)')
    parser_segplot.add_argument('--height', default=6, help='figure height (default = 6)')
    parser_segplot.add_argument('--pointsize', default=1, help='point size for scatterplot (default = 1)')
    parser_segplot.add_argument('--ymin', default=-0.15, help='ymin (default = -0.15)')
    parser_segplot.add_argument('--ymax', default=1.15, help='ymax (default = 1.15)')
    parser_segplot.add_argument('--ylabel', default='pct methylation', help='set label for y-axis (default: pct methylation)')
    parser_segplot.add_argument('--tiltlabel', default=False, action='store_true')
    parser_segplot.add_argument('--palette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_segplot.add_argument('--ridge_alpha', default=1.0, help='alpha (transparency) for ridge plot fills (default = 1.0)')
    parser_segplot.add_argument('--ridge_spacing', default=-0.25, help='ridge plot spacing (generally negative, default = -0.25)')
    parser_segplot.add_argument('--svg', default=False, action='store_true')

    # options for locus-specific plots
    parser_locus.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line (whitespace-delimited)')
    parser_locus.add_argument('-b', '--bams', default=None, help='one or more .bams with Mm and Ml tags for modification calls (see samtags spec)')
    parser_locus.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_locus.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_locus.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_locus.add_argument('-m', '--mods', default=None, help='mods, comma-delimited for >1 (default to all available mods)')
    parser_locus.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_locus.add_argument('-t', '--slidingwindowstep', default=1, help='step size for initial sliding window (default=1)')
    parser_locus.add_argument('-p', '--panelratios',  default=None, help='Alter panel ratios: needs to be 5 comma-separated integers. Default: 1,5,1,3,3')
    parser_locus.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_locus.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_locus.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_locus.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_locus.add_argument('--phased', default=False, action='store_true', help='split samples into phases')
    parser_locus.add_argument('--ignore_ps', default=False, action='store_true', help='do not use phase set (PS) when plotting phased data (HP only)')
    parser_locus.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_locus.add_argument('--include_unphased', default=False, action='store_true', help='include an "unphased" category if called with --phased')
    parser_locus.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_locus.add_argument('--readmarker', default='o', help='marker for (un)methylated glyphs in read panel (matplotlib markers, default=o)')
    parser_locus.add_argument('--markeralpha', default=1.0, help='alpha (transparency) for (un)methylation marker (default=1.0)')
    parser_locus.add_argument('--readmarkersize', default=2.0, help='marker size for (un)methylated glyphs in read panel (default=2.0)')
    parser_locus.add_argument('--readlinewidth', default=1.0, help='width for lines representing read alignments (default=1.0)')
    parser_locus.add_argument('--readlinealpha', default=0.5, help='alpha (transparency) for read mapping lines (default=0.5)')
    parser_locus.add_argument('--readopenmarkeredgecolor', default=None, help='edge color for open (unmethylated) markers in read plot (default = sample color)')
    parser_locus.add_argument('--slidingwindowsize', default=2, help='size of initial sliding window for coverage check (default=2)')
    parser_locus.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_locus.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_locus.add_argument('--maskcutoff', default=1, help='read count masking cutoff (default=1)')
    parser_locus.add_argument('--maxmaskedfrac', default=1.0, help='skip smoothed plot if fraction of sample masked (--maskcutoff) > this value (default = 1.0)')
    parser_locus.add_argument('--nomask', default=False, action='store_true', help='skip drawing segment masks')
    parser_locus.add_argument('--mincalls', default=0, help='drop modspace positions if call count (meth+unmeth) < --mincalls (default=0)')
    parser_locus.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_locus.add_argument('--modspace', default=None, help='spacing between links in top panel (default=auto)')
    parser_locus.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_locus.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_locus.add_argument('--methcall_ymax', default=None)
    parser_locus.add_argument('--methcall_xmin', default=None)
    parser_locus.add_argument('--methcall_xmax', default=None)
    parser_locus.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_locus.add_argument('--statname', default=None, help='label for raw statistic plot')
    parser_locus.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_locus.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_locus.add_argument('--genepalette', default="viridis", help='colour palette name for genes (default = "viridis")')
    parser_locus.add_argument('--highlight_panels', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_locus.add_argument('--excl_ambig', action='store_true', default=False)
    parser_locus.add_argument('--unambig_highlight', action='store_true', default=False)
    parser_locus.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_locus.add_argument('--height', default=8, help='image height (inches, default=8)')
    parser_locus.add_argument('--svg', action='store_true')

    # options for region plots
    parser_region.add_argument('-i', '--interval', required=True, help='chrom:start-end')
    parser_region.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_region.add_argument('-b', '--bams', default=None, help='one or more .bams with Mm and Ml tags for modification calls (see samtags spec)')
    parser_region.add_argument('-n', '--norm_motif', required=True, help='normalise window sizes to motif occurrence')
    parser_region.add_argument('-r', '--ref', required=True, help='ref genome fasta, required if normalising windows with -n/--norm_motif')
    parser_region.add_argument('-g', '--gtf', default=None, help='genes or intervals to display in gtf format')
    parser_region.add_argument('-l', '--highlight', default=None, help='format: start-end, (can be chrom:start-end but chrom is ignored) can comma-delimit multiple highlights')
    parser_region.add_argument('-w', '--windows', default=None, help='set window count, default=auto')
    parser_region.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_region.add_argument('-m', '--mods', default=None, help='mods to consider (comma-delimited, default = all available)')
    parser_region.add_argument('-s', '--smoothwindowsize', default=None, help='size of window for smoothing (default=auto)')
    parser_region.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_region.add_argument('--allreads', default=False, action='store_true', help='show all alignments (secondary/supplementary alignments hidden by default)')
    parser_region.add_argument('--highlight_bed', default=None, help='BED3+1 format (chrom, start, end, optional_colour) where colour (optional) must be intelligible to matplotlib')
    parser_region.add_argument('--motifsize', default=2, help='mod motif size, only used with -b/--bams (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')
    parser_region.add_argument('--maxuncovered', default=50.0, help='maximum percentage of uncovered windows tolerated (default = 50.0)')
    parser_region.add_argument('--modspace', default=None, help='increase to increase spacing between links in top panel (default=auto)')
    parser_region.add_argument('--readmask', default=None, help='mask reads from being shown in interval(s) (start-end or chrom:start-end; chrom ignored). Can be comma-delimited.')
    parser_region.add_argument('--min_window_calls', default=1, help='minimum reads per window to include in plot (default = 1)')
    parser_region.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_region.add_argument('--smoothalpha', default=1.0, help='alpha (transparency) value for smoothed plot (default = 1.0)')
    parser_region.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_region.add_argument('--samplepalette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_region.add_argument('--highlightpalette', default="Blues", help='colour palette name for highlights (default = "Blues")')
    parser_region.add_argument('--genepalette', default="viridis", help='colour palette name for genes (default = "viridis")')
    parser_region.add_argument('--gene_track_height', default=None, help='maximum number of gene track layers')
    parser_region.add_argument('--highlight_panels', default=0.25, help='alpha for highlighting in panels (between 0 and 1, default = 0.25)')
    parser_region.add_argument('--panelratios',  default=None, help='Alter panel ratios: needs to be 4 comma-separated integers. Default: 1,5,3,3')
    parser_region.add_argument('--nticks', default=10, help='tick count (default=10)')
    parser_region.add_argument('--skip_align_plot', default=False, action='store_true', help='blank alignment plot, useful if unneeded or for runtime.')
    parser_region.add_argument('--force_align_plot', default=False, action='store_true', help='retain alignment plot even over regions > 5Mbp where it would be disabled automatically')
    parser_region.add_argument('--genes', default=None, help='genes of interest (comma delimited)')
    parser_region.add_argument('--labelgenes', default=False, action='store_true', help='plot gene names')
    parser_region.add_argument('--width', default=16, help='image width (inches, default=16)')
    parser_region.add_argument('--height', default=8, help='image height (inches, default=8)')
    parser_region.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_region.add_argument('--color_by_hp', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_region.add_argument('--svg', action='store_true', default=False)

    # options for composite plots
    parser_composite.add_argument('-d', '--data', default=None, help='text file with .bam filename and corresponding methylation database per line(whitespace-delimited)')
    parser_composite.add_argument('-b', '--bams', default=None, help='one or more .bams with Mm and Ml tags for modification calls (see samtags spec)')
    parser_composite.add_argument('-s', '--segdata', required=True, help='BED3+1: chrom, start, end, strand')
    parser_composite.add_argument('-r', '--ref', required=True, help='ref genome fasta')
    parser_composite.add_argument('-t', '--teref', required=True, help='TE ref fasta')
    parser_composite.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_composite.add_argument('-c','--palette', default="tab10", help='palette for samples (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_composite.add_argument('-a', '--alpha', default=0.3, help='alpha (default: 0.3)')
    parser_composite.add_argument('-w', '--linewidth', default=1, help='line width (default: 1)')
    parser_composite.add_argument('-l', '--lenfrac', default=0.95, help='fraction of TE length that must align (default 0.95)')
    parser_composite.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_composite.add_argument('--meanplot_ylabel', default='% methylation', help='set y-axis label on mean plot')
    parser_composite.add_argument('--mod', default=None, help='modification to plot (mod codes will be listed, default: infer from sample name)')
    parser_composite.add_argument('--motif', default='CG', help='modified motif to highlight (default = CG)')
    parser_composite.add_argument('--blocks', default=None, help='blocks to highlight (txt file with start, end, name, hex colour)')
    parser_composite.add_argument('--start', default=None, help='start plotting at this base (default None)')
    parser_composite.add_argument('--end', default=None, help='end plotting at this base (default None)')
    parser_composite.add_argument('--mincalls', default=100, help='minimum call count to include elt (default = 100)')
    parser_composite.add_argument('--maxsegs', default=300, help='maximum elements, if > max random.sample() (default = 300)')
    parser_composite.add_argument('--outelts', default=100, help='maximum output elements, if > max random.sample() (default = 100)')
    parser_composite.add_argument('--slidingwindowsize', default=10, help='size of sliding window for meth frac (default 10)')
    parser_composite.add_argument('--slidingwindowstep', default=1, help='step size for meth frac (default 1)')
    parser_composite.add_argument('--smoothwindowsize', default=8, help='size of window for smoothing (default 8)')
    parser_composite.add_argument('--smoothfunc', default='hanning', help='smoothing function, one of: flat,hanning,hamming,bartlett,blackman (default = hanning)')
    parser_composite.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_composite.add_argument('--excl_ambig', action='store_true', default=False)
    parser_composite.add_argument('--phased', action='store_true', default=False, help='currently only considers two phases (diploid)')
    parser_composite.add_argument('--color_by_phase', default=False, action='store_true', help='color samples by HP value (req --phased)')
    parser_composite.add_argument('--svg', action='store_true', default=False)

    # options for whole genome output
    parser_wgmeth.add_argument('-b', '--bam', required=True, help='bam used for methylation calling')
    parser_wgmeth.add_argument('-d', '--methdb', default=None, help='methylation database')
    parser_wgmeth.add_argument('-s', '--binsize', default=1000000, help='bin size for parallelisation, default = 1000000')
    parser_wgmeth.add_argument('-f', '--fai', required=True, help='fasta index (.fai)')
    parser_wgmeth.add_argument('-m', '--mod', default=None, help='output for specific mod (names vary, see output for hints)')
    parser_wgmeth.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_wgmeth.add_argument('-c', '--chrom', default=None, help='limit analysis to one chromosome')
    parser_wgmeth.add_argument('-q', '--min_mapq', default=10, help='minimum mapping quality (mapq), default = 10')
    parser_wgmeth.add_argument('--max_read_density', default=None, help='filter reads with call density greater >= value, can be helpful in footprinting assays (default=None)')
    parser_wgmeth.add_argument('--dss', default=False, action='store_true', help='output in DSS format (default = bedMethyl)')
    parser_wgmeth.add_argument('--phased', action='store_true', default=False, help='split output into phases (currently just 1,2)')

    # options for custom db
    parser_custom.add_argument('-m', '--methdata', required=True, help='per-read methylation output table')
    parser_custom.add_argument('--header', default=False, action='store_true', help='input table has header')
    parser_custom.add_argument('--delimiter', default=None, help='column delimiter char (default = whitespace (i.e. tab or space))')
    parser_custom.add_argument('--readname', required=True, help='readname column number')
    parser_custom.add_argument('--chrom', required=True, help='chromosome column number')
    parser_custom.add_argument('--pos', required=True, help='genomic (i.e. on chromosome/contig) position column number, 0-based')
    parser_custom.add_argument('--strand', required=True, help='strand column number')
    parser_custom.add_argument('--modprob', required=True, help='column number for probability of modified base')
    parser_custom.add_argument('--canprob', default=None, help='column number for probability of canonical base (if not given, assume p=1-modprob)')
    parser_custom.add_argument('--modbasecol', default=None, help='column number for modified base/motif name (optional, can use --modbase instead)')
    parser_custom.add_argument('--modbase', default=None, help='specify modified base/motif name (overrides --modbasecol)')
    parser_custom.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_custom.add_argument('--minmodprob', default=0.8, help='probability threshold for calling modified base (default = 0.8)')
    parser_custom.add_argument('--mincanprob', default=None, help='probability threshold for calling canonical base (default = minmodprob)')
    parser_custom.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_custom.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for megalodon db
    parser_megalodon.add_argument('-m', '--methdata', required=True, help='megalodon per_read_text methylation output')
    parser_megalodon.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_megalodon.add_argument('-p', '--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_megalodon.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_megalodon.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for guppy db
    parser_guppy.add_argument('-s', '--samplename', required=True, help='name for sample')
    parser_guppy.add_argument('-f', '--fast5', required=True, help='fast5 with called bases')
    parser_guppy.add_argument('-p', '--procs', default=1, help='multiprocessing')
    parser_guppy.add_argument('-m', '--motif', required=True, help='motif e.g. G[A]TC or [C]G')
    parser_guppy.add_argument('-n', '--modname', required=True, help='mod name in guppy fast5 modified base alphabet (5mC, 6mA, etc)')
    parser_guppy.add_argument('-b', '--bam', required=True, help='.bam file containing alignments of reads from fast5')
    parser_guppy.add_argument('-r', '--ref', required=True, help='reference genome fasta (samtools faidx indexed)')
    parser_guppy.add_argument('--minprob', default=0.8, help='probability threshold for calling modified or unmodified base (default = 0.8)')
    parser_guppy.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_guppy.add_argument('--include_unmatched', action='store_true', default=False, help='include sites where read base does not match genome base')
    parser_guppy.add_argument('--motifsize', default=2, help='mod motif size (default is 2 as "CG" is most common use case, e.g. set to 1 for 6mA)')

    # options for nanopolish db
    parser_nanopolish.add_argument('-m', '--methdata', required=True, help='whole genome nanopolish methylation output, can be comma-delimited')
    parser_nanopolish.add_argument('-d', '--db', default=None, help='database name (default: auto-infer)')
    parser_nanopolish.add_argument('-t', '--thresh', default=2.5, help='llr threshold (default = 2.5; if using --scalegroup the suggested setting is 2.0)')
    parser_nanopolish.add_argument('-a', '--append', default=False, action='store_true', help='append to database')
    parser_nanopolish.add_argument('-s', '--scalegroup', default=False, action='store_true', help='scale threshold by number of CpGs in a group')
    parser_nanopolish.add_argument('-n', '--modname', default='CpG', help='modification type (tag if combining multiple mods, default = "CpG")')
    parser_nanopolish.add_argument('--motif', default='CG', help='mod motif (default = CG)')

    # options for adjusting methylation / unmethylation cutoffs in methylartist db
    parser_adjustcutoffs.add_argument('-d', '--db', required=True, help='methylartist database')
    parser_adjustcutoffs.add_argument('--mod', required=True, help='modification to plot (will list for user if incorrect)')
    parser_adjustcutoffs.add_argument('-m', '--methylated', required=True, help='mark as methylated above cutoff value')
    parser_adjustcutoffs.add_argument('-u', '--unmethylated', required=True, help='mark as unmethylated below cutoff value')

    # options for score distribution exploration function
    parser_scoredist.add_argument('-d', '--db', required=True, help='methylartist database(s), can be comma-delimited')
    parser_scoredist.add_argument('-n', '--n', default=1000000, help='sample size (default = 1000000)')
    parser_scoredist.add_argument('-m', '--mod', required=True, help='modification to plot (will list for user if incorrect)')
    parser_scoredist.add_argument('--xmin', default=None)
    parser_scoredist.add_argument('--xmax', default=None)
    parser_scoredist.add_argument('--lw', default=2, help='line width (default = 2)')
    parser_scoredist.add_argument('--palette', default="tab10", help='palette for phases (default = "tab10"), see https://seaborn.pydata.org/tutorial/color_palettes.html')
    parser_scoredist.add_argument('--svg', action='store_true', default=False)


    args = parser.parse_args()
    main(args)
