#!/usr/bin/env python

"""
Make redrock stellar templates

This code uses DESI templates and code from

https://github.com/desihub/desispec
https://github.com/desihub/desisim
https://github.com/sbailey/empca

"""
from __future__ import absolute_import, division, print_function

import sys, os
import random
import optparse

import numpy as np
from astropy.io import fits
from empca import empca
import redrock

from desisim.templates import STAR
from desispec.interpolation import resample_flux

#- Command line interface.  -o/--outfile is required; --version has no
#- default; the remaining options fall back to the defaults given here.
parser = optparse.OptionParser(usage = "%prog [options]")
parser.add_option("-o", "--outfile", type=str,  help="Output filename")
parser.add_option("--niter", type=int,  help="Number of EMPCA iterations to run [%default]", default=5)
parser.add_option("--nvec", type=int,  help="Number of basis vectors to generate [%default]", default=5)
parser.add_option("--seed", type=int,  help="Seed for desisim.templates.STAR [%default]", default=12345)
parser.add_option('--version', type=str, help='Template version')

opts, args = parser.parse_args()

#- Check inputs before doing anything else.
#- The diagnostic and the usage text go to stderr so they are not mixed into
#- stdout; the exit status (1) is unchanged.
if opts.outfile is None:
    print('ERROR: Must provide -o/--outfile', file=sys.stderr)
    parser.print_help(sys.stderr)
    sys.exit(1)

#- Output wavelength grid: 3000 to 11000 inclusive in steps of dw
#- (the +dw/2 keeps the end point in the np.arange result)
dw = 0.1
wave = np.arange(3000, 11000+dw/2, dw)

#- Draw nstar rest-frame stellar templates, reproducibly via --seed
nstar = 10000
star_maker = STAR()
tflux, twave, meta, tobjmeta = star_maker.make_templates(
    nstar, restframe=True, seed=opts.seed)

#- Resample every template onto the common grid, then divide it by its own
#- median so that PCA captures variance rather than normalization
flux = np.zeros((nstar, wave.size))
for istar in range(nstar):
    row = flux[istar]       #- view into flux; edits below happen in place
    row[:] = resample_flux(wave, twave, tflux[istar])
    row /= np.median(row)

#- Each stellar type gets its own PCA.
#- Keys are spectral type letters; values are [lower, upper] temperature
#- bounds used to select the templates belonging to that type.
typetemp = {
    #- DESI stellar templates don't cover O-type with Teff>30k; start with B
    'B': [10000, 30000],
    'A': [7500, 10000],
    'F': [6000, 7500],
    'G': [5200, 6000],
    'K': [3700, 5200],
    'M': [2400, 3700],
}

#- Output files are named <outbase>-<spectype><outext>
outbase, outext = os.path.splitext(opts.outfile)

#- Location of the input basis templates, recorded in every output header.
#- Looked up once before any fitting so that a missing variable raises
#- KeyError up front instead of after an EMPCA fit has already been run.
inspec = os.environ['DESI_BASIS_TEMPLATES']

#- Fitted EMPCA model for each spectral type, keyed by type letter
mx = dict()
for spectype, (mintemp, maxtemp) in typetemp.items():
    #- Select the templates whose TEFF lies within this type's range.
    #- NOTE(review): both bounds are inclusive, so a template whose TEFF sits
    #- exactly on a shared boundary is used for both neighbouring types --
    #- confirm this is intended.
    ii = (mintemp <= tobjmeta['TEFF']) & (tobjmeta['TEFF'] <= maxtemp)
    nspec = np.count_nonzero(ii)
    print('RR: Using {} spectra to get the templates'.format(nspec))

    #- No meaningful basis can be derived from zero spectra; skip this type
    #- instead of fitting an empty array and writing the result to disk.
    if nspec == 0:
        print('RR: WARNING: no templates found for spectral type {}; '
              'skipping'.format(spectype), file=sys.stderr)
        continue

    #- EMPCA fit
    print('RR: Fitting basis vectors for spectral type {} with {} templates'.format(spectype, nspec))
    mx[spectype] = model = empca(flux[ii], niter=opts.niter, nvec=opts.nvec)

    #- Scale each input spectrum's coefficient vector to unit length
    print('RR: Normalizing coefficients for all inputs')
    for i in range(model.coeff.shape[0]):
        model.coeff[i] /= np.linalg.norm(model.coeff[i])

    #- Write output
    #- TODO: move to redrock.io.write_template
    header = fits.Header()
    header['CRVAL1'] = (wave[0], 'restframe starting wavelength [Angstroms]')
    header['CDELT1'] = dw
    header['RRTYPE']   = 'STAR'
    header['RRSUBTYP'] = spectype
    header['RRVER'] = redrock.__version__
    header['VERSION'] = (opts.version, 'Template version')
    header['INSPEC'] = inspec
    header['SEED'] = opts.seed
    header['EXTNAME'] = 'BASIS_VECTORS'

    #- Primary HDU holds the basis vectors; the second HDU holds the
    #- normalized per-spectrum coefficients
    hdus = fits.HDUList()
    hdus.append(fits.PrimaryHDU(model.eigvec, header=header))
    hdus.append(fits.ImageHDU(model.coeff, name='ARCHETYPE_COEFF'))

    outfile = outbase + '-' + spectype + outext
    hdus.writeto(outfile, overwrite=True)
    print('RR: Wrote '+outfile)
