#!/usr/bin/env python


import os
import sys
import argparse
from datetime import datetime
import numpy as np
import pyfasta
import vcfnp


def log(*msg):
    """Write a timestamped, '[vcf2npy]'-prefixed message to standard error.

    Every positional argument is converted with ``str()`` and the results are
    joined with single spaces to form the message text. The stream is flushed
    after each message.
    """
    # N.B., write to the stream directly instead of using the Python 2 only
    # "print >>stream" statement, so the same code runs under Python 2 and 3
    line = '[vcf2npy] ' + str(datetime.now()) + ' :: ' + ' '.join(map(str, msg))
    sys.stderr.write(line + '\n')
    sys.stderr.flush()


def _chromosome_split_region(chromosome, chrom_size, task_size, task_index):
    """Build the 'CHROM:START-STOP' region string for one split of a chromosome.

    The chromosome is divided into consecutive windows of `task_size` base
    pairs and the zero-based `task_index` selects one of them. START is
    written 1-based; STOP is not clipped to the chromosome length, so the last
    window may extend past the end of the chromosome.

    Raises ValueError if `task_size` is less than 1 and IndexError if
    `task_index` does not select one of the windows.
    """
    if task_size < 1:
        raise ValueError('task size must be at least 1 base pair, got %r' % (task_size,))

    # number of windows needed to cover the chromosome (ceiling division)
    n_tasks = -(-chrom_size // task_size)

    # N.B., reject negative indices explicitly; plain sequence indexing would
    # silently wrap them round to the last windows of the chromosome
    if not 0 <= task_index < n_tasks:
        raise IndexError(
            'task index %s out of range for chromosome %s (%s tasks of %s bp, zero-based)'
            % (task_index, chromosome, n_tasks, task_size)
        )

    # N.B., make sure regions will work as sortable file names by left-padding
    # with zeros; the widest number that can occur is the STOP of the last
    # window, which may have more digits than the chromosome length itself
    n_fill = len(str(n_tasks * task_size))

    start = task_index * task_size
    stop = start + task_size
    startstr = str(start + 1).zfill(n_fill)
    stopstr = str(stop).zfill(n_fill)
    return str(chromosome) + ':' + startstr + '-' + stopstr


def main(vcf_filename,
         fasta_filename,
         output_dir,
         array_type,
         chromosome,
         task_size,
         task_index,
         exclude_fields,
         ploidy,
         dtypes,
         arities,
         progress,
         ):
    """Load one type of array from a VCF file via vcfnp, with caching enabled.

    Parameters:

    vcf_filename -- path of the input VCF file; must exist.
    fasta_filename -- path of the reference FASTA file; only used (and only
        required to exist) when both `chromosome` and `task_size` are given,
        to look up the chromosome length.
    output_dir -- passed to vcfnp as `cachedir`.
    array_type -- one of 'variants', 'calldata' or 'calldata_2d'; selects the
        vcfnp function of the same name.
    chromosome -- chromosome to extract, or None to extract the whole genome.
    task_size -- window size in base pairs, or None to extract the whole
        chromosome; ignored when `chromosome` is None.
    task_index -- zero-based index of the window to extract; only used when
        both `chromosome` and `task_size` are given.
    exclude_fields, dtypes, arities, progress -- passed through to vcfnp
        unchanged.
    ploidy -- passed through to vcfnp for the two calldata array types only.

    Raises AssertionError if a required file name is missing or the file does
    not exist, ValueError or IndexError if the requested window is invalid,
    and Exception if `array_type` is not recognised.
    """

    assert vcf_filename is not None, 'missing VCF filename, try vcf2npy --help'
    assert os.path.exists(vcf_filename), 'VCF file not found'
    log('loading', array_type, 'from', vcf_filename)

    # determine region to extract
    if chromosome is None:
        log('extract whole genome')
        region = None
    elif task_size is None:
        log('extract whole chromosome', chromosome)
        region = chromosome
    else:
        log('extract a chromosome split', chromosome, task_size, task_index)
        assert fasta_filename is not None, 'missing FASTA filename, try vcf2npy --help'
        assert os.path.exists(fasta_filename), 'FASTA file not found'
        # the reference genome is only needed for the chromosome length
        genome = pyfasta.Fasta(fasta_filename)
        chrom_size = len(genome[chromosome])
        region = _chromosome_split_region(chromosome, chrom_size, task_size, task_index)
        log(region)

    # keyword arguments shared by all three vcfnp loaders
    common = dict(
        region=region,
        progress=progress,
        exclude_fields=exclude_fields,
        dtypes=dtypes,
        arities=arities,
        cache=True,
        cachedir=output_dir,
        skip_cached=True,
        verbose=True,
    )

    if array_type == 'variants':
        A = vcfnp.variants(vcf_filename, flatten_filter=True, **common)
    elif array_type == 'calldata':
        A = vcfnp.calldata(vcf_filename, ploidy=ploidy, **common)
    elif array_type == 'calldata_2d':
        A = vcfnp.calldata_2d(vcf_filename, ploidy=ploidy, **common)
    else:
        raise Exception('unexpected array type: %s' % array_type)

    # NOTE(review): the loaders can apparently return None (presumably when
    # skip_cached finds an existing cache file) -- confirm against vcfnp
    if A is not None:
        log('loaded', A.size, 'items', A.nbytes, 'bytes')
    log('all done')


if __name__ == '__main__':

    # handle command line options
    epilog = "Version: vcfnp %s" % vcfnp.__version__
    parser = argparse.ArgumentParser(epilog=epilog)
    parser.add_argument('--vcf',
                        dest='vcf_filename', metavar='FILE', default=None,
                        help='input VCF file')
    parser.add_argument('--fasta',
                        dest='fasta_filename', metavar='FASTA', default=None,
                        help='reference genome as FASTA file (only required if --task-size is given)')
    parser.add_argument('--output-dir',
                        dest='output_dir', metavar='DIR', default=None,
                        help='output directory (omit to use default vcfnp cache directory)')
    parser.add_argument('--array-type',
                        dest='array_type', default='variants',
                        help='array type, one of [variants|calldata|calldata_2d]')
    parser.add_argument('--chromosome',
                        dest='chromosome', default=None,
                        help='chromosome to extract (omit to extract whole genome)')
    parser.add_argument('--task-size',
                        dest='task_size', metavar='BP', default=None, type=int,
                        help='size (in base pairs) of region to extract (omit to extract whole chromosome)')
    # N.B., no type=int here because the value may also be the name of an
    # environment variable holding the index; it is resolved below
    parser.add_argument('--task-index',
                        dest='task_index', metavar='N', default=None,
                        help='task index as integer or string to get task from environment variable (e.g., "SGE_TASK_ID")')
    parser.add_argument('--exclude-field',
                        dest='exclude_fields', metavar='FIELD', default=None, action='append',
                        help='field to exclude')
    parser.add_argument('--ploidy',
                        dest='ploidy', default=2, type=int,
                        help='sample ploidy')
    parser.add_argument('--dtype',
                        metavar='FIELD:DTYPE', dest='dtypes', action='append',
                        help='override default dtype for a given field, e.g., "MQ:f4"')
    parser.add_argument('--arity',
                        metavar='FIELD:ARITY', dest='arities', action='append',
                        help='override default arity for a given field, e.g., "AD:2"')
    parser.add_argument('--progress',
                        dest='progress', metavar='N', default=10000, type=int,
                        help='log progress every N rows')
    args = parser.parse_args()

    # determine task index; only relevant when a chromosome is being split
    task_index = None
    if args.chromosome is not None and args.task_size is not None:
        # assume we are running as part of a parallel job array
        # want to determine task index
        if args.task_index is not None:
            try:
                # value given directly as an integer
                task_index = int(args.task_index)
            except ValueError:
                # otherwise treat the value as the name of an environment variable
                task_index = str(args.task_index).strip()
                assert task_index in os.environ, 'expected %r in environment' % args.task_index
                task_index = int(os.environ[task_index])
        # if task_index not given in options, assume SGE job array
        elif 'SGE_TASK_ID' in os.environ:
            task_index = int(os.environ['SGE_TASK_ID'])
        else:
            # N.B., format the message here; passing args as a second exception
            # argument would leave the '%s' placeholder unfilled
            raise Exception('could not determine task index; %s' % (args,))
        task_index -= 1  # use zero-based indices internally

    # determine dtype overrides, given as FIELD:DTYPE strings
    if args.dtypes:
        dtypes = dict(s.split(':') for s in args.dtypes)
    else:
        dtypes = None

    # determine arity overrides, given as FIELD:ARITY strings with integer ARITY
    if args.arities:
        arities = dict((s.split(':')[0], int(s.split(':')[1])) for s in args.arities)
    else:
        arities = None

    main(vcf_filename=args.vcf_filename,
         fasta_filename=args.fasta_filename,
         output_dir=args.output_dir,
         array_type=args.array_type,
         chromosome=args.chromosome,
         task_size=args.task_size,
         task_index=task_index,
         exclude_fields=args.exclude_fields,
         ploidy=args.ploidy,
         dtypes=dtypes,
         arities=arities,
         progress=args.progress,
         )

