#!/usr/bin/env python


import os
import sys
import argparse
import numpy as np
import h5py
import vcflib
from glob import glob
from datetime import datetime


def log(*msg):
    """Write a timestamped progress message to standard error.

    All positional arguments are converted with ``str`` and joined with single
    spaces, prefixed with the program tag and the current local time, and
    terminated with a newline. The stream is flushed so that progress is
    visible immediately even when stderr is redirected.
    """
    # sys.stderr.write is used rather than the Python 2 print statement so the
    # function behaves identically under Python 2 and Python 3
    line = '[vcfnpy2hdf5] ' + str(datetime.now()) + ' :: ' + ' '.join(map(str, msg))
    sys.stderr.write(line + '\n')
    sys.stderr.flush()


def _rows_per_chunk(chunk_nbytes, itemsize):
    """Return how many items of `itemsize` bytes fit in `chunk_nbytes` (at least 1)."""
    # Floor division keeps the result an int on both Python 2 and Python 3;
    # the lower bound of 1 avoids requesting a zero-length chunk dimension
    # when a single item is larger than the requested chunk size.
    return max(1, chunk_nbytes // itemsize)


def _load_variants(group, input_path_template, args):
    """Rebuild the 'variants' and 'POS' datasets in `group` from the variants npy files.

    Returns a tuple ``(variants_dataset, variants_input_filenames)``.
    Raises IndexError if no input file matches the variants file pattern.
    """
    log('process variants')

    if 'variants' in group:
        log('deleting existing variants dataset')
        del group['variants']
    pattern = input_path_template.format(array_type='variants')
    variants_input_filenames = sorted(glob(pattern))
    if not variants_input_filenames:
        # same exception type as indexing the empty list would raise, but with
        # a message that says what was being looked for
        raise IndexError('no variants files found matching %s' % pattern)

    log('determine dtype')
    A = np.load(variants_input_filenames[0])
    dtype = A.dtype
    log(dtype)

    log('determine chunk shape')
    chunks = _rows_per_chunk(args.chunk_size, dtype.itemsize),
    log(chunks)

    log('create dataset')
    shape = 0,  # initially empty
    maxshape = None,  # resizable
    variants = group.create_dataset('variants', shape=shape, dtype=dtype,
                                    chunks=chunks, maxshape=maxshape,
                                    compression=args.compression,
                                    compression_opts=args.compression_opts)

    log('load data')
    for fn in variants_input_filenames:
        A = np.load(fn)
        if A.size > 0:
            # grow the dataset by the size of this file, then fill the new tail
            n = variants.shape[0]
            n_add = A.shape[0]
            n_new = n + n_add
            log('loading', n, n_new)
            variants.resize(n_new, axis=0)
            variants[n:n_new] = A
    log('loaded', variants.size)

    if 'POS' in group:
        log('deleting existing POS dataset')
        del group['POS']
    log('create POS dataset')
    # the POS field is copied into its own uncompressed, unchunked dataset
    group.create_dataset('POS', data=variants['POS'], chunks=None, compression=None)

    return variants, variants_input_filenames


def _load_calldata_2d(group, input_path_template, args, n_variants, n_variants_files, samples):
    """Rebuild one dataset per calldata_2d field under ``group['calldata_2d']``.

    Each dataset has shape ``(n_variants, len(samples))``. Raises AssertionError
    if the number of calldata_2d files differs from `n_variants_files`.
    """
    log('process calldata')

    calldata_2d = group.require_group('calldata_2d')
    calldata_2d_input_filenames = sorted(glob(input_path_template.format(array_type='calldata_2d')))
    nvf = n_variants_files
    nc2df = len(calldata_2d_input_filenames)
    # check there is one calldata_2d file for each variants file
    assert nvf == nc2df, 'bad number of calldata_2d files, expected %s, found %s' % (nvf, nc2df)

    log('determine dataset names')
    A = np.load(calldata_2d_input_filenames[0])
    dtype = A.dtype
    log(dtype.names)

    for f in dtype.names:
        t = dtype.fields[f][0]
        log(f, 'dtype', t)
        shape = n_variants, len(samples)
        log(f, 'shape', shape)
        # try to get ~100kb chunks, optimise for accessing column-wise
        chunks = _rows_per_chunk(args.chunk_size, t.itemsize), 1
        log(f, 'chunkshape', chunks)
        if chunks[0] > shape[0]:
            # NOTE(review): this makes the whole dataset a single chunk rather
            # than one chunk per column -- confirm that this is intended
            chunks = shape
        if f in calldata_2d:
            log(f, 'deleting existing dataset')
            del calldata_2d[f]
        log(f, 'create dataset')
        calldata_2d.create_dataset(f, shape=shape, dtype=t, chunks=chunks,
                                   compression=args.compression,
                                   compression_opts=args.compression_opts)

    log('load data')
    start = 0
    for fn in calldata_2d_input_filenames:
        A = np.load(fn)
        if A.size > 0:
            # files are written back to back along the first axis, in sorted
            # filename order
            stop = start + A.shape[0]
            log('loading', start, stop)
            for f in dtype.names:
                calldata_2d[f][start:stop, ...] = A[f]
            start = stop


def main(args):
    """Load npy arrays into an HDF5 file.

    `args` is the parsed command-line namespace; the attributes read are
    ``input_dir``, ``vcf_filename``, ``input_filename_template``,
    ``output_filename``, ``group``, ``chunk_size``, ``compression``,
    ``compression_opts`` and ``variants_only``.

    The sample names from the VCF are stored in the file-level 'samples'
    attribute. Existing 'variants', 'POS' and calldata_2d datasets in the
    destination group are deleted and recreated. The HDF5 file is opened in
    append mode and is always closed before returning, including on error.

    Raises AssertionError if the input directory or VCF file is missing or if
    the calldata_2d and variants file counts differ, and IndexError if no
    variants files are found.
    """

    # guard conditions
    assert args.input_dir is not None and os.path.exists(args.input_dir)
    assert args.vcf_filename is not None and os.path.exists(args.vcf_filename)

    # template for input arrays
    input_path_template = os.path.join(args.input_dir, args.input_filename_template)

    log('open hdf5 file', args.output_filename)
    h5f = h5py.File(args.output_filename, 'a')
    try:
        log('store metadata from VCF', args.vcf_filename)
        vcf = vcflib.PyVariantCallFile(args.vcf_filename)
        samples = vcf.sampleNames
        h5f.attrs['samples'] = samples

        log('setup output group', args.group)
        group = h5f.require_group(args.group)

        variants, variants_input_filenames = _load_variants(group, input_path_template, args)

        if not args.variants_only:
            _load_calldata_2d(group, input_path_template, args,
                              variants.size, len(variants_input_filenames), samples)
    finally:
        # make sure everything is flushed and the handle released even if
        # loading fails part-way through
        h5f.close()

    log('all done')


if __name__ == '__main__':

    # handle command line args
    parser = argparse.ArgumentParser(
        description='Load npy arrays of variants and calldata into an HDF5 file.')
    # --vcf, --input-dir and --output have no usable default, so they are
    # marked required: omitting one yields a usage message from argparse
    # instead of an assertion failure or an error from deep inside main()
    parser.add_argument('--vcf',
                        dest='vcf_filename', metavar='VCF', required=True,
                        help='VCF file to extract metadata from')
    parser.add_argument('--input-dir',
                        dest='input_dir', metavar='DIR', required=True,
                        help='input directory containing npy files')
    parser.add_argument('--input-filename-template',
                        dest='input_filename_template', metavar='TEMPLATE',
                        default='{array_type}*.npy',
                        help='template for input file names, defaults to "{array_type}*.npy"')
    parser.add_argument('--output',
                        dest='output_filename', metavar='HDF5', required=True,
                        help='name of output HDF5 file')
    parser.add_argument('--group',
                        metavar='GROUP', dest='group', default='/',
                        help='destination group in HDF5 file, defaults to root group')
    parser.add_argument('--chunk-size',
                        dest='chunk_size', type=int, metavar='NBYTES', default=100000,
                        help='chunk size (defaults to 100kb)')
    parser.add_argument('--compression',
                        dest='compression', metavar='NAME', default=None,
                        help='compression, default is None')
    parser.add_argument('--compression-opts',
                        dest='compression_opts', metavar='LEVEL', type=int, default=None,
                        help='compression level, applies only to gzip')
    parser.add_argument('--variants-only',
                        dest='variants_only', action='store_true', default=False,
                        help="load variants only, don't look for calldata")
    args = parser.parse_args()
    main(args)
