#!/usr/bin/env python
"""
mbank_injfile
-------------

A script to create an xml file with random injections.
If a flow file is given, the injections will be randomly sampled from the flow.
Otherwise, if a bank is given, they will be randomly chosen among the templates.

To perform injections:

	mbank_injfile --options-you-like

You can also load (some) options from an ini-file:

	mbank_injfile --some-options other_options.ini

Make sure that the mbank is properly installed.
To know which options are available:

	mbank_injfile --help
"""
#TODO: make this file better, with support for boundaries and stuff :D
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings

from mbank import cbc_bank, variable_handler
from mbank.flow import STD_GW_Flow
from mbank.utils import  save_injs
from mbank.parser import updates_args_from_ini, add_range_options

import argparse
import os

##### Creating parser
	#The module docstring is the description of the tool (the first positional argument of ArgumentParser is `prog`, not the description)
	#The raw formatter preserves the line breaks of the docstring in the --help output
parser = argparse.ArgumentParser(
	description = __doc__,
	formatter_class = argparse.RawDescriptionHelpFormatter)
add_range_options(parser)

	#Several options below are mandatory but are not marked as required here:
	#they can also be provided by an ini file, which is read only after the command line is parsed
parser.add_argument(
	"--run-dir", default = None,
	help="Run directory. Unless explicitly stated, every input file will be understood to be in this run-dir")
parser.add_argument(
	"--flow-file", required = False, type=str,
	help="A flow file to draw the injections from. If not given, the injections will be read from the given --bank-file")
parser.add_argument(
	"--bank-file", required = False, type=str,
	help="The path to the bank file to pick injections from. Ignored if --flow-file is set")
parser.add_argument(
	"--inj-out-file", required = False, type=str,
	help="Injection file (in xml format) in which the injections are saved")
parser.add_argument(
	"--variable-format", required = False,
	help="Choose which variables to include in the bank. Valid formats are those of `mbank.handlers.variable_format`")
parser.add_argument(
	"--gps-start",  required = False, type=int,
	help="Start GPS time for the injections")
parser.add_argument(
	"--gps-end",  required = False, type=int,
	help="End GPS time for the injections")
parser.add_argument(
	"--time-step",  required = False, type=float,
	help="Distance in time between consecutive injections")
parser.add_argument(
	"--f-min",  default = 10., type=float,
	help="Minimum frequency (in Hz) for the injections")
parser.add_argument(
	"--f-max",  default = 1024., type=float,
	help="Maximum frequency (in Hz) for the injections")
parser.add_argument(
	"--distance-range", default = [100., 1000.], type=float, nargs = 2,
	help="Minimum and maximum luminosity distance for the injections (in Mpc)")
parser.add_argument(
	"--approximant", default = 'IMRPhenomPv2threePointFivePN',
	help="LAL approximant for the injection generation")
parser.add_argument(
	"--fixed-sky-loc-polarization", type = float, nargs = 3, default = None,
	help="Sky localization and polarization angles for the signal injections. They must be a tuple of float in the format (longitude,latitude,polarization). If None, the angles will be loaded from the injection file, if given, or uniformly drawn from the sky otherwise.")

	#Anything the parser does not recognize is collected in `filenames` and treated as an ini file
args, filenames = parser.parse_known_args()

	#updating from the ini file(s), if it's the case
for f in filenames:
	args = updates_args_from_ini(f, args, parser)

####################
#Interpreting the parser and calling the function

	#These options are mandatory: they are checked here by looking for options still set to None
required_opts = ('inj_out_file', 'variable_format', 'gps_start', 'gps_end', 'time_step')
if any(getattr(args, opt) is None for opt in required_opts):
	raise ValueError("The arguments inj-out-file, variable-format, gps-start, gps-end and time-step must be set!")

	#An explicit exception is used instead of an assert, since asserts are stripped when python runs with -O
if args.variable_format not in variable_handler().valid_formats:
	raise ValueError("Wrong value {} for variable-format".format(args.variable_format))

if not args.approximant.endswith('PN'):
	warnings.warn("GstLAL wants the injection approximant name to end with a the required PN order, e.g. IMRPhenomPv2 -> IMRPhenomPv2threePointFivePN . Are you sure you want to omit it?")
else:
	print("Using approximant ", args.approximant)

	#--run-dir defaults to None: in that case, files are looked for in the current working directory
if args.run_dir is None: args.run_dir = './'
if not args.run_dir.endswith('/'): args.run_dir = args.run_dir+'/'
if not os.path.exists(args.run_dir):
	raise ValueError("The given run directory --run-dir '{}' does not exist".format(args.run_dir))

	#A file name without any '/' is understood to live inside the run directory
	#Anything that looks like a path (i.e. has a '/') is left untouched
if args.flow_file is not None and '/' not in args.flow_file:
	args.flow_file = args.run_dir+args.flow_file
if args.bank_file is not None and '/' not in args.bank_file:
	args.bank_file = args.run_dir+args.bank_file
if '/' not in args.inj_out_file:
	args.inj_out_file = args.run_dir+args.inj_out_file

min_dist, max_dist = args.distance_range
rng = np.random.default_rng()

	#A non-positive time step or a reversed GPS window would give a division by zero or a negative number of injections
if args.time_step <= 0:
	raise ValueError("The argument time-step must be positive, got {}".format(args.time_step))
if args.gps_end < args.gps_start:
	raise ValueError("The argument gps-end ({}) must not be smaller than gps-start ({})".format(args.gps_end, args.gps_start))

	#Number of injections: how many time steps fit in the GPS window
N_injs = int((args.gps_end - args.gps_start)/args.time_step)

	#The flow file, if given, takes precedence over the bank file
if args.flow_file is not None:
	print("Generating random injections from the given flow file: {}".format(args.flow_file))
	flow = STD_GW_Flow.load_flow(args.flow_file)

		#Gradients are not needed for sampling
	import torch
	with torch.no_grad():
		injs = flow.sample(N_injs).numpy() #(N_injs, D)

elif args.bank_file is not None:
	print("Loading injections from the given bank: {}".format(args.bank_file))
	bank = cbc_bank(args.variable_format, args.bank_file)

		#Templates are drawn uniformly at random, with replacement
		#The random indices make any prior shuffling of the bank unnecessary
	ids_ = rng.choice(bank.templates.shape[0], N_injs)

	injs = bank.templates[ids_, :] #(N_injs, D)

else:
	raise ValueError("Either --flow-file or --bank-file options must be non empty")

	#if sky_locs is None, injections will be drawn at random from the sky inside save_injs
sky_locs = np.array(args.fixed_sky_loc_polarization) if args.fixed_sky_loc_polarization is not None else None

	#Converting the sampled points from the chosen variable format to the BBH components expected by save_injs
var_handler = variable_handler()
injs = var_handler.get_BBH_components(injs, args.variable_format)

save_injs(args.inj_out_file, injs, args.gps_start, args.gps_end, args.time_step, args.approximant,
	luminosity_distance = (min_dist, max_dist), sky_locs = sky_locs, f_min = args.f_min, f_max = args.f_max)





















