#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      mkgmov
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################


"""

# do not reload if the file is similar


"""


import sys
import os
import string
import getopt
import types

from numpy import *
from pNbody import *
from pNbody import Movie
from pNbody.param import *
from pNbody.libutil import *
from pNbody import iofunc as io
import copy


# parameter file written after the first frame with the parameters actually
# used; it is reloaded when an existing movie is continued
SAVEDPARAMS = 'saved_parameters.py'


##########################################################################
def read_params(paramname):
    """
    Build a dictionary of parameter values from an Nbody parameter file.

    Every parameter known by ``Params(paramname, None)`` is stored under its
    name (item 0 of each entry of ``.params``) with its value (item 3).
    The keys specific to mkgmov ('exec', 'macro', 'n1', 'n2', 'select',
    'ftype', 'subs', 'fdir', 'tdir', 'pfile') are then set to None.
    """
    nbody_params = Params(paramname, None)

    values = {entry[0]: entry[3] for entry in nbody_params.params}

    # additional mkgmov parameters ('tdir' is the track directory)
    extra_keys = ('exec', 'macro', 'n1', 'n2', 'select', 'ftype',
                  'subs', 'fdir', 'tdir', 'pfile')
    for extra in extra_keys:
        values[extra] = None

    return values


##########################################################################
def write_params(file, nh, nw, width, height, params):
    """
    Save the parameters used for the movie into a python parameter file.

    The written file can be executed back by mkgmov (it is the file reloaded
    when an existing movie is continued): it sets nh, nw, width and height
    and rebuilds the ``param`` dictionary through initparams().

    file          : name of the file to write (overwritten)
    nh, nw        : number of sub-images per row and number of rows
    width, height : size of a single sub-image
    params        : {snapshot name: {image number: {key: value}}}; only the
                    entry of the first snapshot (in sorted order) is written,
                    so that the min/max/cd values computed for the first
                    frame are saved too.

    Relies on the module level ``gparams`` (a Params instance) to get the
    type of each parameter.
    """

    header = """
nh = %d  # number of horizontal frames (sub-images per row)
nw = %d  # number of vertical frames (rows)
# size of subfilms
width = %d
height = %d
# size of the film
numByte = width * nh
numLine = height * nw
# init parameters
param = initparams(nh,nw)\n""" % (nh, nw, width, height)

    # take params from the first file
    first = params[sorted(params.keys())[0]]

    with open(file, 'w') as fd:
        fd.write(header)

        for i in first:
            fd.write('\n')
            for name, value in first[i].items():
                tpe = gparams.get_type(name)
                txt = write_ascii_value(value, tpe, name)

                # string parameters must be quoted to be read back as python
                if tpe == "String" and value is not None:
                    line = """param[%d]['%s'] = "%s"\n""" % (i, name, txt)
                else:
                    line = """param[%d]['%s'] = %s\n""" % (i, name, txt)

                fd.write(line)


##########################################################################
def initparams(nh, nw):
    """
    Create the per-image parameter dictionaries for nh*nw sub-images.

    Images are numbered from 1 to nh*nw.

    param    : parameters coming from the .py parameter file; each image
               starts with 'tdir' and 'pfile' set to None (returned)
    allparam : all the parameters (Nbody parameter file + .py file); each
               image starts with an empty dictionary (set as a module global)

    Returns param.
    """

    global allparam

    numbers = range(1, nh * nw + 1)

    # both keys must exist: the previous version overwrote 'tdir' with 'pfile'
    param = {n: {'tdir': None, 'pfile': None} for n in numbers}
    allparam = {n: {} for n in numbers}

    return param


##########################################################################
def version():
    """Print the version of mkgmov and terminate with exit status 0."""
    message = 'version %s' % '2.0'
    print(message)
    raise SystemExit(0)

##########################################################################


def help_message():
    """Print the command line usage of mkgmov and exit with status 0."""
    print("""Usage : mkgmov [option] output files
  Options: -h        -- this help message
           -p file   -- parameter file
           -f file   -- g-parameter file (Nbody parameter file with default values)
           -v        -- verbose mode
           -s        -- disable the softening of rsp
           -c        -- enable auto cd for each image
           -z        -- convert time in redshift
           -u        -- do not reload a file used by the previous image
           --pio     -- read the snapshots with parallel input/output (pio)
           --fits    -- create fits output instead of .gmv film
           --fitsdir -- directory where to save fits files
           --info    -- give the optimal factor for each files in the list
           --help    -- this help message
           --version -- displays version

    """)
    sys.exit(0)

##########################################################################


def check_arguments(options, xarguments):
    """
    Interpret the command line returned by getopt.getopt.

    options    : list of (option, value) pairs
    xarguments : remaining arguments; the first one is the output name, the
                 others are the input files

    Returns the tuple
      (files, output, param, gparam, verbose, info, pio, sofrsp, arange,
       redshift, readonlyonce, mkfits, fitsdir)
    where files is the sorted list of input files.

    help_message() (which exits) is called for -h/--help, when -p, -f or
    --fitsdir has an empty value, and when no output name is given;
    version() (which exits) is called for --version.
    """

    param = None
    gparam = None
    verbose = None
    info = None
    pio = 'no'
    sofrsp = 1
    arange = 0
    redshift = 0
    readonlyonce = 0
    mkfits = 0
    fitsdir = "fits"

    for opt, value in options:
        if opt in ('-h', '--help'):
            help_message()
        elif opt == '--version':
            version()
        elif opt == '--info':
            info = 1
        elif opt == '--pio':
            pio = "yes"
        elif opt == '-v':
            verbose = 1
        elif opt == '-s':
            sofrsp = 0
        elif opt == '-c':
            arange = 1
        elif opt == '-z':
            redshift = 1
        elif opt == '-u':
            readonlyonce = 1
        elif opt == '--fits':
            mkfits = 1
        elif opt in ('--fitsdir', '-p', '-f'):
            # these options need a non empty value
            if value == '':
                help_message()
            elif opt == '--fitsdir':
                fitsdir = value
            elif opt == '-p':
                param = value
            else:
                gparam = value

    try:
        output = xarguments[0]
    except IndexError:
        # no output name given
        help_message()

    files = sorted(xarguments[1:])

    return files, output, param, gparam, verbose, info, pio, sofrsp, arange, redshift, readonlyonce, mkfits, fitsdir


##########################################################################
#
#  MAIN
#
##########################################################################


if mpi.mpi_IsMaster():

    # glob lists the files of a 'tdir' directory below; import it explicitly
    # instead of relying on the star imports
    import glob

    try:
        # 'fitsdir=' : --fitsdir takes a value (the output directory)
        options, xarguments = getopt.getopt(sys.argv[1:], 'p:f:hvsczu', [
                                            'info', 'pio', 'fits', 'fitsdir=', 'help', 'version'])
    except getopt.error:
        help_message()
        sys.exit(0)

    # check arguments
    files, output, param_file, gparam_file, verbose, info, pio, sofrsp, arange, redshift, readonlyonce, mkfits, fitsdir = check_arguments(
        options, xarguments)

    nbody_parameter_file = gparam_file

    # check that the output does not already exist (or ask what to do)
    complete = 0
    if not info:
        if not mkfits:
            if (os.path.exists(output) != 0):

                answer = input(
                    '%s exists. Remove it (r) Continue it (c) Exit (q) ? ' %
                    (output))

                if len(answer) == 0 or answer[0] == 'c':
                    # continue the movie with the parameters saved when it
                    # was started; fall back to the given parameter file
                    # if they are not available
                    if os.path.exists(SAVEDPARAMS):
                        param_file = SAVEDPARAMS
                    else:
                        print("Warning : %s does not exists." % (SAVEDPARAMS))
                        print("Using   : %s instead." % (param_file))
                    complete = 1
                elif answer[0] == 'r':
                    os.remove(output)
                else:
                    sys.exit()

    ##########################################################################
    # from snapshots or files, build a list of files
    ##########################################################################

    ##############################################################
    # read the param file if it exists, otherwise use default values

    if param_file is None:
        # number of horizontal and vertical sub-films
        nh = 1
        nw = 1
        # size of subfilms (pixel counts, hence integer division)
        width = 128
        height = width // 2 + width // 4
        # info on subfilms
        param = initparams(nh, nw)
    elif os.path.exists(param_file):
        # the parameter file is python code defining nh, nw, width, height
        # and param (it may call initparams)
        exec(compile(open(param_file).read(), param_file, 'exec'))
    else:
        print("Error : %s does not exists." % (param_file))
        sys.exit()

    ##############################################################
    # complete allparam with default values
    ##############################################################

    if nbody_parameter_file is not None:
        paramname = nbody_parameter_file
    else:
        paramname = PARAMETERFILE

    # read the default parameters
    # paramname = PARAMETERFILE				  # ! ! ! not that good ! ! !
    gparams = Params(paramname, None)

    defparam = read_params(paramname)			  # ! ! ! read again	! ! !

    # complete allparam
    n = 1
    for i in range(nh * nw):
        allparam[n] = copy.deepcopy(defparam)
        n = n + 1

    #############################################
    # read from pfile files if exists

    n = 1
    # loop over all sub-images
    for i in range(nh * nw):

        n = i + 1

        if param[n]['pfile'] is not None:

            paramfile = param[n]['pfile']

            if not os.path.isfile(paramfile):
                print("file %s does not exists..." % (paramfile))
                sys.exit()

            allparam[n] = read_params(paramfile)

    ##############################################################
    # create params with allparam and add paramfiles
    ##############################################################

    params = {}
    listfiles = []

    basefiles = []
    for file in files:
        basefiles.append(os.path.basename(file))

    #############################################
    # read from trail files if exists

    n = 1
    # loop over all sub-images and verify that tdir exists
    for i in range(nh * nw):
        if 'tdir' not in param[n]:
            param[n]['tdir'] = None
        n = n + 1

    n = 1
    # loop over all sub-images and verify that time exists
    for i in range(nh * nw):
        if 'time' not in param[n]:
            txt = "You must define the parameter time : ex param[1]['time'] = 'nb.time'"
            raise Exception("parameter Error", txt)
        n = n + 1

    # loop over all sub-images: the first one whose 'tdir' is an existing
    # directory defines the list of files to process
    for n in range(1, nh * nw + 1):

        # here, we use param (init with .py)
        if param[n]['tdir'] is not None:

            directory = param[n]['tdir']
            if os.path.isdir(directory):
                paramfiles = sorted(glob.glob('%s/*' % (directory)))

                # loop over files
                for paramfile in paramfiles:

                    snap_file = os.path.basename(paramfile)
                    snap_file = os.path.splitext(snap_file)[0]

                    try:
                        index = basefiles.index(snap_file)
                    except ValueError:
                        print("file %s not in file list" % (snap_file))
                        sys.exit()

                    listfiles.append(files[index])

                # stop after first time here
                break

    if len(listfiles) != 0:
        files = listfiles

    # now, loop over files
    old_name = ''
    num = -1
    listfiles = []

    for file in files:

        # set name and redefine files (listfiles)
        name = os.path.basename(file)
        # add number to the file
        if name == old_name:
            num = num + 1
        else:
            num = 0

        old_name = name
        name = "%s.%05d" % (name, num)
        file = "%s.%05d" % (file, num)
        listfiles.append(file)

        n = 1
        # loop over images
        param_for_this_file = {}

        for i in range(nh * nw):
            if param[n]['tdir'] is not None:
                directory = param[n]['tdir']

                paramfile = os.path.join(directory, name)

                if not os.path.isfile(paramfile):
                    print("file %s does not exists..." % (paramfile))
                    sys.exit()

                param_for_this_file[n] = read_params(paramfile)

            else:		  # if no params file for this image (use .py file)
                param_for_this_file[n] = allparam[n]

            # add parameter to the list
            params[file] = param_for_this_file
            n = n + 1

    files = listfiles

    ##############################################################
    # complete params with param
    ##############################################################

    # loop over files
    for file in files:
        n = 1
        # loop over images
        for i in range(nh * nw):

            keys = list(param[n].keys())
            for key in keys:
                params[file][n][key] = param[n][key]
            n = n + 1

    ##############################################################
    # open the movie file
    ##############################################################

    if not mkfits:

        # size of the film: nh images per row, nw rows
        numByte = width * nh
        numLine = height * nw

        # open the film
        if not info:
            f = Movie.Movie(output)

            if not complete:
                f.new(numByte, numLine)

            else:
                f.open('r+')

    ##############################################################
    # create directory for file
    ##############################################################

    else:
        if not os.path.exists(fitsdir):
            os.mkdir(fitsdir)


else:

    info = None
    pio = None
    arange = None
    redshift = None
    readonlyonce = None
    mkfits = None

    nh = None
    nw = None
    width = None
    height = None

    files = None
    params = None

#############################################################
# gather all parameters to the slaves
#############################################################

# gather options

# broadcast the command line options (fitsdir and complete are only used by
# the master, so they are not sent)
(info, pio, arange, redshift, readonlyonce, mkfits) = [
    mpi.mpi_bcast(opt_value, 0)
    for opt_value in (info, pio, arange, redshift, readonlyonce, mkfits)
]

# broadcast the geometry of the film
(nh, nw, width, height) = [
    mpi.mpi_bcast(size_value, 0) for size_value in (nh, nw, width, height)
]

# broadcast the list of frames and the parameters of each frame
(files, params) = [mpi.mpi_bcast(item, 0) for item in (files, params)]


##########################################################################
##########################################################################
##
# main loop over all files
##
##########################################################################
##########################################################################


mn_opt = 0.
mx_opt = 0.
cd_opt = 0.

# set to 0 once the last frame already stored in the movie is reached
# (only used by the master when continuing a movie)
cflag = 1

# last snapshot read (used by -u to avoid reloading it)
last_file = None

# index of the current file (used to name the fits outputs)
ifile = -1

# loop over all files
for file in files:

    ifile = ifile + 1

    # When continuing an existing movie, the master skips the frames already
    # stored in it. The decision is broadcast so that every process skips
    # the same frames (the slaves do not know the movie state).
    skip = 0
    if mpi.mpi_IsMaster():
        if complete and cflag:
            data = f.read_one()

            if f.current_time < f.stoptime:
                skip = 1
            elif f.current_time == f.stoptime:
                cflag = 0
                skip = 1
    skip = mpi.mpi_bcast(skip, 0)
    if skip:
        continue

    # loop over all images
    data = []
    for fnum in range(1, nh * nw + 1):

        ###############################
        # take parameters from params
        ###############################

        obs = params[file][fnum]['obs']
        if isinstance(obs, list):
            obs = read_ascii_value(obs, 'ArrayObs', 'obs')

        if params[file][fnum]['x0'] is None:
            x0 = None
        else:
            x0 = array(params[file][fnum]['x0'], float)
        if params[file][fnum]['xp'] is None:
            xp = None
        else:
            xp = array(params[file][fnum]['xp'], float)

        alpha = params[file][fnum]['alpha']
        view = params[file][fnum]['view']
        r_obs = params[file][fnum]['r_obs']
        clip = params[file][fnum]['clip']
        cut = params[file][fnum]['cut']
        eye = params[file][fnum]['eye']
        dist_eye = params[file][fnum]['dist_eye']
        foc = params[file][fnum]['foc']
        persp = params[file][fnum]['persp']

        shape = (width, height)
        #center =		params[file][fnum]['center']
        size = params[file][fnum]['size']
        frsp = params[file][fnum]['frsp']
        space = params[file][fnum]['space']
        mode = params[file][fnum]['mode']
        rendering = params[file][fnum]['rendering']
        filter_name = params[file][fnum]['filter_name']
        filter_opts = params[file][fnum]['filter_opts']
        scale = params[file][fnum]['scale']
        cd = params[file][fnum]['cd']
        mn = params[file][fnum]['mn']
        mx = params[file][fnum]['mx']
        l_n = params[file][fnum]['l_n']
        l_min = params[file][fnum]['l_min']
        l_max = params[file][fnum]['l_max']
        l_kx = params[file][fnum]['l_kx']
        l_ky = params[file][fnum]['l_ky']
        l_color = params[file][fnum]['l_color']
        l_crush = params[file][fnum]['l_crush']
        b_weight = params[file][fnum]['b_weight']
        b_xopts = params[file][fnum]['b_xopts']
        b_yopts = params[file][fnum]['b_yopts']
        b_color = params[file][fnum]['b_color']

        # other params
        n1 = params[file][fnum]['n1']
        n2 = params[file][fnum]['n2']
        select = params[file][fnum]['select']
        exec_param = params[file][fnum]['exec']
        macro = params[file][fnum]['macro']
        ftype = params[file][fnum]['ftype']
        time = params[file][fnum]['time']

        # multi images
        subs = params[file][fnum]['subs']

        # dir parameter
        fdir = params[file][fnum]['fdir']

        # set min and max: reuse the values of the first file (possibly
        # computed while rendering it) when none are given
        if mn == 0. and mx == 0. and cd == 0. and file != files[0]:
            mn = params[files[0]][fnum]['mn']
            mx = params[files[0]][fnum]['mx']
            cd = params[files[0]][fnum]['cd']

        #
        # check if the file need to be read
        #
        must_be_read = True
        if readonlyonce:
            if last_file == os.path.splitext(file)[0]:
                must_be_read = False

        last_file = os.path.splitext(file)[0]

        if must_be_read:
            # open the model (remove the particle if there is one !!! not good
            # !!!)
            fl = os.path.splitext(file)[0]
            if pio == 'yes':
                fl = os.path.splitext(fl)[0]

            if mpi.mpi_IsMaster():
                print("reading ", fl)
            nb = Nbody(fl, ftype=ftype, pio=pio)

            # convert the time only once per loaded snapshot: a snapshot
            # reused with -u has already been converted
            if redshift:
                nb.tnow = 1 / nb.tnow - 1

        ###############################
        # take parameters from dir
        ###############################
        """
    not implemented now
    """

        if fdir is not None:
            tnow = nb.tnow

        if arange == 1:
            mn = 0.
            mx = 0.
            cd = 0.

        if view == 'dk':
            # dark (empty) image
            mat = zeros(width * height)
            mat = mat.astype(int8)
            mat = mat.tobytes()
            data.append(mat)
        else:

            if must_be_read:

                ############
                # set time
                ############
                exec("nb.tnow = %s" % time)

                ############
                # exec
                ############
                if exec_param is not None:
                    exec(exec_param)

                ############
                # macro
                ############
                if macro is not None:
                    exec(compile(open(macro).read(), macro, 'exec'))

                ############
                # sub (n1,n2)
                ############
                if n1 is not None and n2 is not None:
                    nb = nb.sub(n1, n2)

                ############
                # select
                ############
                if select is not None:
                    nb = nb.select(select)

            if mkfits:
                mat = nb.CombiMap(obs=obs,
                                  x0=x0,
                                  xp=xp,
                                  alpha=alpha,
                                  view=view,
                                  r_obs=r_obs,
                                  eye=eye,
                                  dist_eye=dist_eye,
                                  foc=foc,
                                  mode=mode,
                                  rendering=rendering,
                                  space=space,
                                  persp=persp,
                                  clip=clip,
                                  size=size,
                                  cut=cut,
                                  frsp=frsp,
                                  shape=shape)

                if mpi.mpi_IsMaster():

                    # fitsdir comes from the command line (--fitsdir)
                    output = '%04d_%04d.fits' % (ifile, fnum)
                    output = os.path.join(fitsdir, output)

                    print("%8.3f %002d %s" % (nb.tnow, fnum, output))

                    if os.path.exists(output):
                        os.remove(output)

                    print(time)
                    header = [('TIME', nb.tnow, 'snapshot time')]
                    io.WriteFits(transpose(mat), output, extraHeader=header)

            else:

                # compute map1
                mat, matint, mn_opts, mx_opts, cd_opts = nb.Map(obs=obs,
                                                                x0=x0,
                                                                xp=xp,
                                                                alpha=alpha,
                                                                view=view,
                                                                r_obs=r_obs,
                                                                eye=eye,
                                                                dist_eye=dist_eye,
                                                                foc=foc,
                                                                mode=mode,
                                                                rendering=rendering,
                                                                space=space,
                                                                persp=persp,
                                                                clip=clip,
                                                                size=size,
                                                                cut=cut,
                                                                frsp=frsp,
                                                                shape=shape,
                                                                filter_name=filter_name,
                                                                filter_opts=filter_opts,
                                                                scale=scale,
                                                                cd=cd,
                                                                mn=mn,
                                                                mx=mx,
                                                                l_color=l_color,
                                                                l_n=l_n,
                                                                l_min=l_min,
                                                                l_max=l_max,
                                                                l_kx=l_kx,
                                                                l_ky=l_ky,
                                                                l_crush=l_crush,
                                                                b_weight=b_weight,
                                                                b_xopts=b_xopts,
                                                                b_yopts=b_yopts,
                                                                b_color=b_color,
                                                                subs=subs)

                # !!!!!!!!!!!!!!!!!!!!
                # the optimal values must be retrieved
                # !!!!!!!!!!!!!!!!!!!!
                # !!! source of problems: everything should be kept...
                mn_opt = mn_opts[0]
                mx_opt = mx_opts[0]
                cd_opt = cd_opts[0]

                # add the image to the matrices
                matint = transpose(matint.astype(int8))
                data.append(matint.tobytes())

                if mpi.mpi_IsMaster():
                    print("%8.3f %002d min=%10.3e max=%10.3e cd=%10.3e" % (nb.tnow, fnum, mn_opt, mx_opt, cd_opt))

        # remember the values computed for the first file, so that the other
        # files (and the saved parameters) use the same scaling
        if file == files[0] and mn == 0. and mx == 0. and cd == 0.:
            params[file][fnum]['cd'] = cd_opt
            params[file][fnum]['mn'] = mn_opt
            params[file][fnum]['mx'] = mx_opt

    # save used parameters
    if file == files[0]:
        if mpi.mpi_IsMaster():
            write_params(SAVEDPARAMS, nh, nw, width, height, params)

    if not mkfits:

        datac = b''
        for i in range(nw):   # loop over the lines
            # append the nh pictures of the line and add them after the
            # previous lines
            datac = datac + \
                Movie.append_h(width, height, data[i * nh:(i + 1) * nh])

        if not info:
            if mpi.mpi_IsMaster():
                f.write_pic(nb.tnow, datac)


# the movie object f only exists on the master and only when a .gmv film is
# produced (neither --info nor --fits)
if not info and not mkfits:
    if mpi.mpi_IsMaster():
        f.close()
