#!/usr/bin/env python
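"""Load variant and calldata arrays saved as .npy files into a single HDF5
file, optionally copying sample names over from the source VCF.

A minimal usage sketch (paths and option values are illustrative only):

    $ vcfnpy2hdf5 --input-dir /data/npy \
          --vcf /data/samples.vcf \
          --output /data/samples.h5 \
          --compression gzip --compression-opts 3

With the default group layout, the result can then be read back with h5py,
e.g. (assuming the variants arrays include a POS field):

    >>> import h5py
    >>> h5f = h5py.File('/data/samples.h5', 'r')
    >>> h5f['variants']['POS'][:10]

"""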


from __future__ import division, print_function
import os
import sys
import argparse
import numpy as np
import h5py
import vcflib
from glob import glob
from datetime import datetime


def log(*msg):
    print('[vcfnpy2hdf5] ' + str(datetime.now()) + ' :: ' + ' '.join(map(str, msg)),
          file=sys.stderr)
    sys.stderr.flush()


def load_compound_arrays_into_dataset(input_filenames, dataset_name, parent_group, chunk_size, compression, compression_opts):
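    """Concatenate the compound arrays in `input_filenames` into a single
    resizable dataset named `dataset_name` under `parent_group`, replacing
    any existing dataset of that name, and return it."""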

    # the first input array determines the dtype and shape of the dataset
    assert len(input_filenames) > 0, 'no input files found'
    a = np.load(input_filenames[0])

    if dataset_name in parent_group:
        log(dataset_name, 'delete existing dataset')
        del parent_group[dataset_name]

    log(dataset_name, 'setup dataset')
    shape = (0,) + a.shape[1:]  # initially empty
    maxshape = (None,) + a.shape[1:]  # resizable along first dimension
    if a.ndim == 1:
        chunks = (chunk_size // a.dtype.itemsize,)
    else:
        # optimise for access along second dimension
        chunks = (chunk_size // a.dtype.itemsize, 1) + a.shape[2:]
    log(dataset_name, 'dtype', a.dtype, 'shape', shape, 'maxshape', maxshape, 'chunks', chunks)
    ds = parent_group.create_dataset(dataset_name, shape=shape, dtype=a.dtype,
                                     chunks=chunks, maxshape=maxshape,
                                     compression=compression,
                                     compression_opts=compression_opts)

    log(dataset_name, 'load data')
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:
            n = ds.shape[0]
            n_add = a.shape[0]
            n_new = n + n_add
            log(dataset_name, 'loading', n, n_new)
            ds.resize(n_new, axis=0)
            ds[n:n_new] = a

    return ds


def load_compound_arrays_into_group(input_filenames, group_name, parent_group, chunk_size, compression, compression_opts):
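    """Split the compound arrays in `input_filenames` into one resizable
    dataset per dtype field, under a group named `group_name` within
    `parent_group`, replacing any existing datasets of those names, and
    return the group."""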

    # the first input array determines the dtypes and shapes of the datasets
    assert len(input_filenames) > 0, 'no input files found'
    a = np.load(input_filenames[0])
    grp = parent_group.require_group(group_name)

    log(group_name, 'setup datasets')
    for f in a.dtype.names:

        if f in grp:
            log(group_name, f, 'deleting existing dataset')
            del grp[f]

        shape = (0,) + a[f].shape[1:]  # initially empty
        maxshape = (None,) + a[f].shape[1:]  # resizable along first dimension
        if a[f].ndim == 1:
            chunks = (chunk_size // a[f].dtype.itemsize,)
        else:
            # optimise for access along second dimension
            chunks = (chunk_size // a[f].dtype.itemsize, 1) + a[f].shape[2:]
        log(group_name, f, 'dtype', a[f].dtype, 'shape', shape, 'maxshape', maxshape, 'chunks', chunks)
        grp.create_dataset(f, shape=shape, dtype=a[f].dtype,
                           chunks=chunks, maxshape=maxshape,
                           compression=compression,
                           compression_opts=compression_opts)

    log(group_name, 'load data')
    for fn in input_filenames:
        a = np.load(fn)
        if a.size > 0:
            for f in a.dtype.names:
                n = grp[f].shape[0]
                n_add = a.shape[0]
                n_new = n + n_add
                log(group_name, f, 'loading', n, n_new)
                grp[f].resize(n_new, axis=0)
                grp[f][n:n_new, ...] = a[f]

    return grp


def load_hdf5(input_dir,
              output_filename,
              input_filename_template='{array_type}*.npy',
              vcf_filename=None,
              variants_only=False,
              tabulate_variants=False,
              parent_group_name='/',
              chunk_size=100000,
              compression=None,
              compression_opts=None):
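    """Load variants (and, unless `variants_only`, calldata) arrays from
    .npy files under `input_dir` into the HDF5 file `output_filename`,
    creating it if absent and updating it in place otherwise. Under
    `parent_group_name` this produces a `variants` group (or a single
    compound dataset if `tabulate_variants`), a `calldata` group, and, if
    `vcf_filename` is given, a `samples` dataset."""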

    # guard conditions
    assert input_dir is not None and os.path.exists(input_dir), \
        'input directory not found: %r' % input_dir
    if vcf_filename is not None:
        assert os.path.exists(vcf_filename), \
            'VCF file not found: %r' % vcf_filename

    # template for input arrays
    input_path_template = os.path.join(input_dir, input_filename_template)

    log('open hdf5 file', output_filename)
    with h5py.File(output_filename, 'a') as h5f:

        log('setup parent group', parent_group_name)
        parent_group = h5f.require_group(parent_group_name)

        if vcf_filename is not None:
            log('store metadata from VCF', vcf_filename)
            vcf = vcflib.PyVariantCallFile(vcf_filename)
            if 'samples' in parent_group:
                del parent_group['samples']
            # store sample names as fixed-length bytes, which HDF5 supports natively
            samples = np.array(vcf.sampleNames, dtype='S')
            parent_group.create_dataset('samples', data=samples)

        log('process variants')
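        # sort the globbed filenames so input chunks are concatenated in a
        # stable order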
        variants_input_filenames = sorted(glob(input_path_template.format(array_type='variants')))
        if tabulate_variants:
            log('use dataset layout for variants')
            variants = load_compound_arrays_into_dataset(variants_input_filenames, 'variants', parent_group,
                                                         chunk_size=chunk_size,
                                                         compression=compression,
                                                         compression_opts=compression_opts)
        else:
            log('use group layout for variants')
            variants = load_compound_arrays_into_group(variants_input_filenames, 'variants', parent_group,
                                                       chunk_size=chunk_size,
                                                       compression=compression,
                                                       compression_opts=compression_opts)

        if not variants_only:

            log('process calldata')
            calldata_input_filenames = sorted(glob(input_path_template.format(array_type='calldata_2d')))
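            # the variants and calldata_2d files are assumed to pair up
            # one-to-one in sorted filename order; the check below verifies
            # only that the counts match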

            log('check number of input files')
            nvf = len(variants_input_filenames)
            nc2df = len(calldata_input_filenames)
            assert nvf == nc2df, 'bad number of calldata files, expected %s, found %s' % (nvf, nc2df)

            calldata = load_compound_arrays_into_group(calldata_input_filenames, 'calldata', parent_group,
                                                       chunk_size=chunk_size,
                                                       compression=compression,
                                                       compression_opts=compression_opts)

        log('all done')


def main():
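    """Parse command line arguments and run the load."""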

    # handle command line args
    parser = argparse.ArgumentParser()
    parser.add_argument('--vcf',
                        dest='vcf_filename', metavar='VCF', default=None,
                        help='VCF file to extract metadata from')
    parser.add_argument('--input-dir',
                        dest='input_dir', metavar='DIR', default=None,
                        help='input directory containing npy files')
    parser.add_argument('--input-filename-template',
                        dest='input_filename_template', metavar='TEMPLATE',
                        default='{array_type}*.npy',
                        help='template for input file names, defaults to "{array_type}*.npy"')
    parser.add_argument('--output',
                        dest='output_filename', metavar='HDF5', default=None,
                        help='name of output HDF5 file')
    parser.add_argument('--group',
                        metavar='GROUP', dest='group', default='/',
                        help='destination group in HDF5 file, defaults to root group')
    parser.add_argument('--chunk-size',
                        dest='chunk_size', type=int, metavar='NBYTES', default=100000,
                        help='target chunk size in bytes (defaults to 100000)')
    parser.add_argument('--compression',
                        dest='compression', metavar='NAME', default=None,
                        help='compression filter (e.g., gzip or lzf), default is None')
    parser.add_argument('--compression-opts',
                        dest='compression_opts', metavar='LEVEL', type=int, default=None,
                        help='compression level, applies only to gzip')
    parser.add_argument('--variants-only',
                        dest='variants_only', action='store_true', default=False,
                        help="load variants only, don't look for calldata")
    parser.add_argument('--tabulate-variants',
                        dest='tabulate_variants', action='store_true', default=False,
                        help="organise the variants as a dataset with a compound dtype")
    args = parser.parse_args()

    load_hdf5(input_dir=args.input_dir,
              output_filename=args.output_filename,
              input_filename_template=args.input_filename_template,
              parent_group_name=args.group,
              chunk_size=args.chunk_size,
              compression=args.compression,
              compression_opts=args.compression_opts,
              variants_only=args.variants_only,
              tabulate_variants=args.tabulate_variants)


if __name__ == '__main__':
    main()
