#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import click
import gatetools as gt
from gatetools import phsp
import gaga
from torch.autograd import Variable
import torch

# Click context settings: accept '-h' in addition to '--help'.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('tlor_filename', nargs=1)
@click.option('--cyl_radius', '-r', required=True, type=float, help='Cylindrical radius in mm')
@click.option('--cyl_height', required=True, type=float, help='Cylindrical height in mm')
@click.option('-n', default='-1', help='Number of samples')
@click.option('--shuffle', '-s', is_flag=True, default=False, help='Shuffle the n samples (slow if file is large)')
@click.option('--output', '-o', required=True, help='output filename (npy)')
def go(tlor_filename, cyl_radius, cyl_height, n, output, shuffle):
    """
    Convert a phsp (npy or root) file that contains pairs of particles
    parameterized with tlor into a phsp with pairs of particles.

    tlor : Cx Cy Cz Vx Vy Vz Dx Dy Dz Wx Wy Wz t1 t2 t3 E1 E2

    pairs: t1 t2 Ax Ay Az Bx By Bz dAx dAy dAz dBx dBy dBz E1 E2

    tlor is within a cylinder (radius, height).
    """

    # Read data. 'n' is kept as a string option and parsed through float()
    # so that scientific notation (e.g. '1e6') is accepted on the command line.
    n = int(float(n))
    # NOTE: the local is named 'data' (not 'phsp') so that it does not shadow
    # the 'phsp' module imported at the top of the file.
    data, keys, m = gt.phsp.load(tlor_filename, nmax=n, shuffle=shuffle)
    if n == -1:
        # -1 means "no limit": report the count returned by the loader instead
        # (presumably the total number of entries in the file -- TODO confirm).
        n = m
    print('Input ', tlor_filename, n, keys)

    # Parameters for the conversion. float() is a no-op for values already
    # converted by click; it is kept so that direct callers passing strings
    # keep working.
    params = {
        'cyl_radius': float(cyl_radius),
        'cyl_height': float(cyl_height),
        'keys_list': keys
    }

    # Convert numpy to torch --> this function may be embedded in the Generator for Gate.
    # The array is copied first, otherwise torch warns that the
    # 'NumPy array is not writeable'.
    # The deprecated torch.autograd.Variable wrapper is not needed: plain
    # tensors have carried autograd support since PyTorch 0.4.
    dtypef, _ = gaga.init_pytorch_cuda('auto', verbose=False)
    data = torch.from_numpy(data.copy()).type(dtypef)

    # Convert tlor to pairs. The output key names are read back from 'params',
    # so from_tlor_to_pairs is expected to store them under 'keys_output'.
    x = gaga.from_tlor_to_pairs(data, params)
    keys_out = params['keys_output']
    print(f'Output Keys {keys_out}')
    print(f'Output conversion: {len(x)}/{n}')

    # Convert back to numpy and save. detach() is the supported replacement
    # for the legacy '.data' attribute.
    x = x.detach().cpu().numpy()
    gt.phsp.save_npy(output, x, keys_out)


# --------------------------------------------------------------------------
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    go()
