#!/usr/bin/env python

"""
Redrock benchmark speed tests for zscan
"""

from __future__ import absolute_import, division, print_function
from time import time
t0 = time()


def timeit(message):
    """Print the time elapsed since the last checkpoint, then reset it.

    The checkpoint is the module-level ``t0``.  The output line is
    *message* left-aligned in a 10 character field followed by the
    elapsed seconds with 3 decimals.  ``t0`` is re-read from the clock
    after printing, so the print itself is not charged to the next
    interval.
    """
    global t0
    elapsed = time() - t0
    print("{:10s} {:.3f}".format(message, elapsed))
    t0 = time()

import sys, os
from glob import glob
import multiprocessing as mp
import numpy as np
from scipy.sparse import dia_matrix

import redrock as rr
from redrock.dataobj import Spectrum, Target

#- Report the time elapsed since t0 was first set, which covers the imports above
timeit('import')

#- Wavelength grids for the 3 arms (b, r, z), each with a step of 1.0
wb = np.arange(3570, 5950, 1.0)
wr = np.arange(5625, 7740, 1.0)
wz = np.arange(7535, 9830, 1.0)
nb, nr, nz = len(wb), len(wr), len(wz)

#- Fake resolution matrices: one normalized 15-point Gaussian kernel,
#- repeated at every wavelength and stored on diagonals -7..7
xx = np.linspace(-5, 5, 15)
yy = np.exp(-xx**2 / 2)
yy /= np.sum(yy)
diags = np.arange(-7, 8)


def _fake_resolution(nwave):
    """Return an (nwave, nwave) dia_matrix whose k-th diagonal is yy[k]."""
    data = np.repeat(yy[:, np.newaxis], nwave, axis=1)
    return dia_matrix((data, diags), shape=(nwave, nwave))


Rb, Rr, Rz = (_fake_resolution(nwave) for nwave in (nb, nr, nz))

#- Unit inverse variance for every pixel of every arm
ivar_b, ivar_r, ivar_z = np.ones(nb), np.ones(nr), np.ones(nz)

#- Build ntarget fake targets, each with nexp exposures of all three arms;
#- fluxes are pure unit-variance Gaussian noise
ntarget = 32
nexp = 2
arms = ((wb, ivar_b, Rb), (wr, ivar_r, Rr), (wz, ivar_z, Rz))
targets = []
for targetid in range(ntarget):
    #- Per exposure the arms are appended in b, r, z order
    spectra = [
        Spectrum(wave, np.random.normal(size=len(wave)), ivar, R)
        for _ in range(nexp)
        for (wave, ivar, R) in arms
    ]
    targets.append(Target(targetid, spectra))

#- Locate a galaxy template under $RR_TEMPLATE_DIR.  Fail with an explicit
#- message rather than an opaque TypeError (variable unset) or IndexError
#- (no matching file).
template_dir = os.getenv('RR_TEMPLATE_DIR')
if not template_dir:
    raise RuntimeError(
        'RR_TEMPLATE_DIR is not set; point it at the directory '
        'containing the rrtemplate-*.fits files')

#- glob returns matches in arbitrary order; sort so that every run
#- benchmarks the same template file
tfiles = sorted(glob(os.path.join(template_dir, 'rrtemplate-galaxy*.fits')))
if not tfiles:
    raise RuntimeError(
        'no rrtemplate-galaxy*.fits file found in {}'.format(template_dir))

tfile = tfiles[0]
templates = rr.io.read_templates(tfile)
#- Redshift grid scanned by the benchmark: 100 points from 0.5 to 1.5
redshifts = np.linspace(0.5, 1.5, 100)
timeit('setup')

# results = rr.zscan.calc_zchi2_targets(redshifts, targets, templates[0], verbose=False)
# timeit('zscanA')
# results = rr.zscan.calc_zchi2_targets(redshifts, targets, templates[0], verbose=False)
# timeit('zscanB')

#- Now with multiprocessing: time calc_zchi2_targets for every combination
#- of worker-process count (ncpu) and MKL_NUM_THREADS setting (nmkl).
#- NOTE(review): this script has no __main__ guard, so it presumably relies
#- on the "fork" start method; confirm before running on spawn platforms.
for ncpu in (1, 2, 4, 8, 16, 32):
    for nmkl in (1, 2, 4, 8, 16, 32):
        #- NOTE(review): set in the parent environment before the pool is
        #- created; confirm that the workers' MKL actually honours it.
        os.environ['MKL_NUM_THREADS'] = str(nmkl)

        #- Split the targets into one contiguous chunk per worker; max()
        #- keeps the range step valid should ncpu ever exceed ntarget
        stepsize = max(1, ntarget // ncpu)
        args = [
            (redshifts, targets[i:i+stepsize], templates[0])
            for i in range(0, ntarget, stepsize)
        ]

        #- The context manager tears the pool down after each measurement;
        #- otherwise idle workers from all previous iterations accumulate
        with mp.Pool(ncpu) as pool:
            #- Reset the checkpoint so pool start-up is not timed
            t0 = time()
            results = pool.starmap(rr.zscan.calc_zchi2_targets, args)
            timeit('mscan{}-{}'.format(ncpu, nmkl))

#- NOTE(review): the script stops here with exit status 1, so everything
#- below this line never runs; kept as-is, confirm this is intentional.
sys.exit(1)

#- Non-numba multiprocessing
def cholesky(n):
    """Benchmark kernel: Cholesky-factor a random n x n Gram matrix.

    Draws a uniform random (n, n) matrix, forms its Gram matrix and
    factors it with numpy.  The factor is discarded; nothing is returned.
    """
    rand = np.random.uniform(size=(n, n))
    gram = np.dot(rand.T, rand)
    np.linalg.cholesky(gram)

def loop(n):
    """Benchmark kernel: pure-Python double loop over an n x n index grid.

    Accumulates sin(row * col) for every index pair.  The sum is
    discarded; nothing is returned.
    """
    acc = 0.0
    for row in range(n):
        for col in range(n):
            acc = acc + np.sin(row * col)

#- For each kernel: time one serial call ('<label>0'), then time ncpu
#- simultaneous calls in a pool of ncpu workers ('<label><ncpu>')
for label, func, n in (('loop', loop, 1000), ('chol', cholesky, 2500)):
    #- Reset the checkpoint so both serial baselines are measured the same way
    t0 = time()
    func(n)
    timeit(label + '0')

    for ncpu in (1, 2, 4, 8):
        #- One identical task per worker
        args = [(n,)] * ncpu
        #- The context manager tears the pool down after each measurement
        #- instead of leaving its worker processes alive
        with mp.Pool(ncpu) as pool:
            #- Reset the checkpoint so pool start-up is not timed
            t0 = time()
            results = pool.starmap(func, args)
            timeit('{}{}'.format(label, ncpu))


