#!/usr/bin/env python

"""
Make a redrock GALAXY template: EMPCA basis vectors fit to simulated DESI ELG, LRG, and BGS rest-frame spectra

This code uses DESI templates and code from

https://github.com/desihub/desispec
https://github.com/desihub/desisim
https://github.com/sbailey/empca

"""
from __future__ import absolute_import, division, print_function

import sys, os
import random
import optparse

import numpy as np
from astropy.io import fits
from empca import empca
import redrock

from desisim.templates import ELG, LRG, BGS
from desispec.interpolation import resample_flux

parser = optparse.OptionParser(usage = "%prog [options]")
parser.add_option("-o", "--outfile", type=str,  help="Output filename")
parser.add_option("--niter", type=int,  help="Number of EMPCA iterations to run [%default]", default=10)
parser.add_option("--nvec", type=int,  help="Number of basis vectors to generate [%default]", default=10)
parser.add_option("--seed", type=int,  help="Seed for desisim.templates.ELG, LRG, and BGS [%default]", default=123456)
parser.add_option('--version', type=str, help='Template version')

opts, args = parser.parse_args()

#- Check inputs before doing anything else
if opts.outfile is None:
    print('ERROR: Must provide -o/--outfile')
    parser.print_help()
    sys.exit(1)

#- $DESI_BASIS_TEMPLATES is written to the output header (INSPEC) only after
#- the slow template generation and EMPCA fit; validate it up front so a
#- missing variable fails immediately instead of raising KeyError at the end.
if 'DESI_BASIS_TEMPLATES' not in os.environ:
    print('ERROR: $DESI_BASIS_TEMPLATES must be set')
    sys.exit(1)

#- Generate templates and resample to 0.1A grid that covers
#- z=0 to 1.85 for obsframe wavelengths 3500 to 11000, purposefully
#- avoiding Lyman-alpha to focus PCA variations on other lines
dw = 0.1
wave = np.arange(3500/(1+1.85), 11000+dw/2, dw)
nelg = 10000
nlrg = 5000
nbgs = 5000

#- Fill a single preallocated array in place. Building per-class arrays and
#- joining them with chained np.vstack calls copies the data on every stack
#- and keeps the per-class arrays referenced, holding this very large array
#- (nspec x ~1e5 pixels) in memory several times over.
flux = np.zeros((nelg + nlrg + nbgs, wave.size))

#- (template generator class, number of templates, seed), in stacking order.
#- Seeds reproduce the original per-class offsets exactly.
#- NOTE(review): the BGS offset uses nbgs where nlrg may have been intended;
#- both are 5000, so the seed is kept as-is for reproducibility.
template_sets = [
    (ELG, nelg, opts.seed),
    (LRG, nlrg, opts.seed + nelg),
    (BGS, nbgs, opts.seed + nelg + nbgs),
]

row = 0
for template_class, ntemplates, seed in template_sets:
    tflux, twave, _, _ = template_class().make_templates(
        ntemplates, restframe=True, nocolorcuts=True, seed=seed)
    #- Resample each rest-frame template onto the common wavelength grid.
    for i in range(ntemplates):
        flux[row + i] = resample_flux(wave, twave, tflux[i])
    row += ntemplates

#- Normalize spectra so that PCA captures variance rather than normalization.
#- Iterating a 2D array yields row views, so each row is divided in place
#- without allocating a second copy of the full flux array.
for i, spectrum in enumerate(flux):
    norm = np.median(spectrum)
    #- A zero or non-finite median would turn this row into inf/NaN and
    #- silently corrupt the EMPCA fit; fail loudly instead.
    if not np.isfinite(norm) or norm == 0:
        raise ValueError('Template {} has median flux {}; cannot normalize'.format(i, norm))
    spectrum /= norm

#- EMPCA fit: derive opts.nvec basis vectors from the normalized spectra
print('RR: Fitting for basis vectors')
model = empca(flux, niter=opts.niter, nvec=opts.nvec)

#- Rescale each input's coefficient vector to unit L2 norm; each row of
#- model.coeff is a view, so dividing it updates model.coeff in place.
print('RR: Normalizing coefficients for all inputs')
for coeff_row in model.coeff:
    coeff_row /= np.linalg.norm(coeff_row)

#- Write output: basis vectors in the primary HDU (with redrock template
#- keywords) plus per-input coefficients in an ARCHETYPE_COEFF extension.
#- TODO: move to redrock.io.write_template
header_cards = [
    ('CRVAL1', (wave[0], 'restframe starting wavelength [Angstroms]')),
    ('CDELT1', dw),
    ('RRTYPE', 'GALAXY'),
    ('RRSUBTYP', ''),
    ('RRVER', redrock.__version__),
    ('VERSION', (opts.version, 'Template version')),
    ('INSPEC', os.environ['DESI_BASIS_TEMPLATES']),
    ('SEED', opts.seed),
    ('EXTNAME', 'BASIS_VECTORS'),
]
header = fits.Header()
for keyword, card_value in header_cards:
    header[keyword] = card_value

hdus = fits.HDUList([
    fits.PrimaryHDU(model.eigvec, header=header),
    fits.ImageHDU(model.coeff, name='ARCHETYPE_COEFF'),
])
hdus.writeto(opts.outfile, overwrite=True)
print('RR: Wrote ' + opts.outfile)
