#!/usr/bin/env python

"""
Make redrock White Dwarf templates

This code uses DESI templates and code from

https://github.com/desihub/desispec
https://github.com/desihub/desisim
https://github.com/sbailey/empca

"""
from __future__ import absolute_import, division, print_function

import sys, os
import random
import optparse

import numpy as np
from astropy.io import fits
from empca import empca
import redrock

from desisim.templates import WD
from desispec.interpolation import resample_flux

#- Command-line interface. Each entry is (flag names, add_option keywords);
#- optparse expands %default in the help text to the option's default value.
parser = optparse.OptionParser(usage="%prog [options]")
_option_specs = (
    (("-o", "--outfile"),
     dict(type=str, help="Output filename")),
    (("--niter",),
     dict(type=int, default=5,
          help="Number of EMPCA iterations to run [%default]")),
    (("--nvec",),
     dict(type=int, default=5,
          help="Number of basis vectors to generate [%default]")),
    (("--seed",),
     dict(type=int, default=1234567,
          help="Seed for WD(subtype='..').make_templates [%default]")),
    (("--version",),
     dict(type=str, help='Template version')),
)
for _flags, _settings in _option_specs:
    parser.add_option(*_flags, **_settings)

opts, args = parser.parse_args()

#- The output filename is mandatory: report the problem, show the usage
#- text and exit with status 1 before any templates are generated
if opts.outfile is None:
    print('ERROR: Must provide -o/--outfile')
    parser.print_help()
    sys.exit(1)

#- Generate templates and resample to 0.1A grid
dw = 0.1
wave = np.arange(3000, 11000+dw/2, dw)
nstar_DA = 10000
nstar_DB = 10000

#- Allocate the combined (DA + DB) flux matrix once and fill it in place.
#- flux_DA and flux_DB are row-slice views into `flux`, so no stacked copy
#- of the per-subtype arrays is ever made.
flux = np.zeros((nstar_DA + nstar_DB, len(wave)))
flux_DA = flux[:nstar_DA]
flux_DB = flux[nstar_DA:]

#- One pass per WD subtype: (subtype, number of templates, seed, target rows).
#- The DB seed is offset by nstar_DA so the two subtypes use different seeds.
for subtype, nstar, seed, subflux in (
        ('DA', nstar_DA, opts.seed, flux_DA),
        ('DB', nstar_DB, opts.seed+nstar_DA, flux_DB)):
    tflux, twave, meta, tobjmeta = WD(subtype=subtype).make_templates(
        nstar, restframe=True, seed=seed)
    #- Resample each generated template onto the common `wave` grid
    for i in range(nstar):
        subflux[i] = resample_flux(wave, twave, tflux[i])

#- Normalize spectra so that PCA captures variance rather than normalization
#- (row by row, so only one spectrum is copied at a time by np.median)
for i in range(flux.shape[0]):
    flux[i] /= np.median(flux[i])

#- Split the requested output path into base and extension; both pieces are
#- used later to build the final output filename
outbase, outext = os.path.splitext(opts.outfile)

#- EMPCA fit
print('RR: Fitting basis vectors')
model = empca(flux, niter=opts.niter, nvec=opts.nvec)
mx = model    #- second name bound to the same fitted model object

#- Determine normalized coefficients for all inputs:
#- scale each row of the coefficient array in place to unit Euclidean norm
print('RR: Normalizing coefficients for all inputs')
nrows = model.coeff.shape[0]
irow = 0
while irow < nrows:
    rownorm = np.linalg.norm(model.coeff[irow])
    model.coeff[irow] /= rownorm
    irow += 1

#- Write output
#- TODO: move to redrock.io.write_template
#- Header for the primary HDU: wavelength grid start/step, redrock template
#- type tags, and provenance (redrock version, template version, the value
#- of $DESI_BASIS_TEMPLATES, and the random seed)
header = fits.Header()
header['CRVAL1'] = (wave[0], 'restframe starting wavelength [Angstroms]')
header['CDELT1'] = dw
header['RRTYPE']   = 'STAR'
header['RRSUBTYP'] = 'WD'
header['RRVER'] = redrock.__version__
header['VERSION'] = (opts.version, 'Template version')
header['INSPEC'] = os.environ['DESI_BASIS_TEMPLATES']
header['SEED'] = opts.seed
header['EXTNAME'] = 'BASIS_VECTORS'

#- HDU 0: basis vectors (model.eigvec); HDU 1: normalized coefficients
hdus = fits.HDUList()
hdus.append(fits.PrimaryHDU(model.eigvec, header=header))
hdus.append(fits.ImageHDU(model.coeff, name='ARCHETYPE_COEFF'))

#- Insert '-WD' between the base name and extension of the requested path
outfile = outbase + '-WD' + outext
#- `overwrite` replaces the old `clobber` keyword, which astropy deprecated
#- and later removed; an existing file at `outfile` is replaced
hdus.writeto(outfile, overwrite=True)
print('RR: Wrote '+outfile)
