#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      pnbmov_fits2pngs
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

import Mtools as mt
from Mtools import pyfits
from PIL import Image

import numpy as np
import argparse
import os
import sys
import glob
import copy

from pNbody.palette import Palette
from pNbody import apply_filter

########################################
#
# parser
#
########################################

parser = argparse.ArgumentParser(
    description="Convert fits images into png images, optionally combining "
                "several series of fits files into one image per frame.")


parser.add_argument("-d", "--dir",
                    action="store",
                    dest="dir",
                    type=str,
                    default=None,
                    metavar='DIRECTORY',
                    help="output directory for png files "
                         "(created if it does not exist)")


# Each FILES argument is expanded with glob and sorted by the script and forms
# one series of frames. Quote patterns so the shell does not expand them into
# many single-file series.
parser.add_argument('files',
                    metavar='FILES',
                    type=str,
                    nargs='+',
                    help="fits files: each argument is one series of files and "
                         "may be a quoted glob pattern (e.g. 'maps/*.fits')")


parser.add_argument("-m", "--mode",
                    action="store",
                    dest="mode",
                    type=str,
                    default="generic",
                    help="combination mode: 'palette' (first series only), "
                         "'rgb' (exactly three series) or any other value for "
                         "the generic mode configured by --params")


parser.add_argument("--imode",
                    action="store",
                    dest="imode",
                    type=str,
                    default='RGB',
                    help="image mode of the saved png : L, P, RGB or RGBA")


parser.add_argument("--palette",
                    action="store",
                    dest="palette",
                    type=str,
                    default='light',
                    help="palette name (palette mode only)",
                    metavar="NAME")

parser.add_argument("--scale",
                    action="store",
                    dest="scale",
                    type=str,
                    default='log',
                    help="intensity scale, e.g. log or lin (palette mode only)",
                    metavar="STRING")

parser.add_argument("--mn",
                    action="store",
                    dest="mn",
                    type=float,
                    default=0.0,
                    help="min value (palette mode only)",
                    metavar="FLOAT")

parser.add_argument("--mx",
                    action="store",
                    dest="mx",
                    type=float,
                    default=0.0,
                    help="max value (palette mode only)",
                    metavar="FLOAT")

parser.add_argument("--cd",
                    action="store",
                    dest="cd",
                    type=float,
                    default=0.0,
                    help="cd value (palette mode only)",
                    metavar="FLOAT")


parser.add_argument("--cfact",
                    action="store",
                    dest="cfact",
                    type=float,
                    default=1.0,
                    help="cfact value (rgb mode only)",
                    metavar="FLOAT")


parser.add_argument("--params",
                    action="store",
                    dest="params",
                    type=str,
                    default=None,
                    help="parameter file: Python code executed for every frame "
                         "to set the per-image args[j] entries "
                         "(required in generic mode)",
                    metavar="STRING")


parser.add_argument("-o", "--output",
                    action="store",
                    dest="outputfile",
                    type=str,
                    default=None,
                    help="output file; if given, every frame is written to "
                         "this same file (takes precedence over -d)",
                    metavar="FILE")


opt = parser.parse_args()


################################################################################
#
#                                    MAIN
#
################################################################################

if opt.dir is not None:
    # makedirs also creates missing parent directories; exist_ok keeps an
    # already existing output directory working
    os.makedirs(opt.dir, exist_ok=True)


########################
# create list of files

# one sorted list of files per FILES argument (each may be a glob pattern)
flists = [sorted(glob.glob(fname)) for fname in opt.files]

if opt.mode == 'palette':

    if len(flists) > 1:
        print("WARNING : more than one type of fits file, using 0 only")

    n = len(flists[0])


elif opt.mode == 'rgb':

    if len(flists) != 3:
        print("ERROR : rgb mode needs exactly 3 lists of fits files")
        sys.exit(1)

    # only frames present in all three series can be composed
    n = min(len(flist) for flist in flists)

else:
    # generic mode: one frame per index present in every series
    n = min(len(flist) for flist in flists)

    if opt.params is None:
        print("ERROR : default mode needs to specify a parameter file")
        sys.exit(1)

# some info
print()
print("mode %s" % opt.mode)
print("%d files found" % n)
print()


def ctrans(mat, tg, bg, a3):
    """
    Convert normalized values in [0, 1] into color channel values (0-255 scale).

    The mapping is an exponential curve that equals ``bg`` at 0 and ``tg``
    at 1:

        f(x) = bg + (tg - bg) * (exp(-a3*x) - 1) / (exp(-a3) - 1)

    Parameters
    ----------
    mat : array_like
        Normalized values, nominally in [0, 1].
    tg : float
        Target color value, i.e. the value reached at 1.
    bg : float
        Background color value, i.e. the value at 0.
    a3 : float
        Elbow parameter. |a3| << 1 gives a (nearly) linear mapping and
        a3 == 0 gives exactly the linear limit bg + (tg - bg) * x.
        Larger positive a3 makes the curve rise faster near 0.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``mat``. Values are not clipped.
    """
    mat = np.asarray(mat, dtype=float)

    if tg == bg:
        return np.full(mat.shape, float(bg))

    if a3 == 0:
        # a3 -> 0 limit of the exponential form; the general expression
        # would divide by exp(0) - 1 == 0
        return bg + (tg - bg) * mat

    # expm1 keeps full precision for the small a3 used for near-linear
    # mappings (e.g. the 0.01 default of the main loop)
    return bg + (tg - bg) * np.expm1(-a3 * mat) / np.expm1(-a3)


###################################
# loop over all files
###################################


for i in range(n):

    mats = []

    # first input file of this frame (printed at the end of the iteration);
    # also visible to the parameter file executed in generic mode
    file = flists[0][i]

    if opt.mode == 'palette':
        data = pyfits.open(flists[0][i])[0].data
        data = np.transpose(data)
        img = mt.fits_apply_palette(
            data, scale=opt.scale, cd=opt.cd, mn=opt.mn, mx=opt.mx, palette=opt.palette)

    elif opt.mode == 'rgb':
        img = mt.fits_compose_rgb_img(
            [flists[0][i], flists[1][i], flists[2][i]], opt.cfact)

    else:

        nc = len(flists)

        args = []
        imgs = []

        #########################################
        # 0) default parameters for each image (one per series); the
        #    parameter file executed below overrides them via args[j][...]
        #########################################

        for j in range(nc):

            arg = {}

            arg['cd'] = None
            arg['mn'] = None
            arg['mx'] = None
            arg['palette'] = None
            arg['scale'] = 'lin'

            arg['ar'] = None
            arg['ag'] = None
            arg['ab'] = None
            arg['aa'] = 0.0

            arg['fc'] = 1.       # factor applied to the normalized image
            arg['bg'] = 0        # background [0-255] 0=black 255=white
            arg['a3'] = 0.01     # elbow of ctrans; small values -> (nearly) linear

            # slope that determines the alpha dependence on s2 (signal of the 2nd image)
            # positive value : the image will dominate over the background
            # negative value : the background will dominate
            arg['qs2'] = 0.1

            arg['smode'] = "alpha"  # superposition mode: "alpha", "alpha_gamma" or "add"

            arg['filter_name'] = None
            arg['filter_opts'] = None

            # exponent used by smode "alpha_gamma"; 1 behaves like "alpha"
            arg['gamma'] = 1.0

            args.append(arg)

        # read the parameters
        # NOTE: the parameter file is executed as Python code in this script's
        # global namespace (it may read/modify args, nc, i, file, ...);
        # only use trusted parameter files.
        with open(opt.params) as params_file:
            params_src = params_file.read()
        exec(params_src)

        # loop over series
        for j in range(nc):

            # open the model
            data = pyfits.open(flists[j][i])[0].data

            # apply filter
            if args[j]['filter_name'] is not None:
                data = apply_filter(
                    data, name=args[j]['filter_name'], opt=args[j]['filter_opts'])

            #########################################
            # 1) transform to 0-1
            #########################################
            mat, tmp_mn, tmp_mx, tmp_cd = mt.normalize(
                data, scale=args[j]["scale"], cd=args[j]["cd"], mn=args[j]["mn"], mx=args[j]["mx"])

            # scale if needed
            mat = mat*args[j]['fc']

            mats.append(mat)

            print("(%d) mn=%g mx=%g cd=%g" % (j, tmp_mn, tmp_mx, tmp_cd))

            #########################################
            # 2) transform matrix to color
            #########################################

            # a) create an image using a palette
            if args[j]['palette'] is not None:
                # clip before the 8-bit cast: fc (or values outside [0, 1])
                # can push 255*mat outside [0, 255], which would wrap around
                matint = np.clip(255*mat, 0, 255).astype(np.uint8)

                img = Image.frombytes(
                    "P", (matint.shape[1], matint.shape[0]), matint.tobytes())

                palette = Palette(args[j]['palette'])
                img.putpalette(palette.palette)
                img = img.convert('RGB')

            # b) create an image using rgb coefficients
            elif (args[j]['ar'] is not None and args[j]['ag'] is not None
                  and args[j]['ab'] is not None):

                # exponential dependence (for small a3 -> linear)
                bg = float(args[j]['bg'])
                a3 = float(args[j]['a3'])

                size = (mat.shape[1], mat.shape[0])

                # one 8-bit channel per target value
                channels = []
                for key in ('ar', 'ag', 'ab'):
                    chan = np.uint8(np.clip(ctrans(mat, args[j][key], bg, a3), 0, 255))
                    channels.append(Image.frombytes("L", size, chan.tobytes()))

                img = Image.merge('RGB', tuple(channels))

            else:
                # without this, a stale img from a previous series/frame
                # would silently be reused
                print("ERROR : image %d needs either a palette or ar, ag and ab values" % j)
                sys.exit(1)

            size = img.size
            imgs.append(img)

        #########################################
        # 3) compose images: the first image is the background, each
        #    following image is superposed according to its smode
        #########################################

        for j, im in enumerate(imgs):

            # current image as three flat channels, row-major like np.ravel(mats[j])
            pixels = np.asarray(im, dtype=float).reshape(-1, 3)
            r, g, b = pixels[:, 0], pixels[:, 1], pixels[:, 2]

            if j == 0:
                # first image (background)
                rt = r.copy()
                gt = g.copy()
                bt = b.copy()
                continue

            s2 = np.ravel(mats[j])

            # add only if the signal is non zero
            if s2.max() <= 0:
                continue

            smode = args[j]["smode"]

            if smode not in ("alpha_gamma", "alpha", "add"):
                print("unknown smode %s !" % smode)
                print("Aborting !")
                sys.exit(1)

            q = args[j]["qs2"]
            # s2p is 1 where s2 == 0 and 0 where s2 == 1
            s2p = (np.exp(-s2*q) - np.exp(-q)) / (1. - np.exp(-q))

            if smode == "alpha_gamma":
                # Values bigger than 1 lighten the image, values lower than 1
                # darken the image. gamma = 1 corresponds to the alpha mode.
                # NOTE(review): the outer exponent is gamma (not 1/gamma), so
                # channel values (0-255 scale) end up raised to ~gamma**2;
                # confirm this is intended.
                gamma = args[j]["gamma"]

                a1 = 1          # force first image
                a2 = 1 - s2p    # ~ s2 : weight of the new image

                norm = a2 + a1*(1 - a2) + 1e-5

                rt = ((a2*r**gamma + a1*(1.0 - a2)*rt**gamma)/norm)**gamma
                gt = ((a2*g**gamma + a1*(1.0 - a2)*gt**gamma)/norm)**gamma
                bt = ((a2*b**gamma + a1*(1.0 - a2)*bt**gamma)/norm)**gamma

            elif smode == "alpha":
                # normal alpha channel
                a1 = 1          # force first image
                a2 = 1 - s2p    # ~ s2 : weight of the new image

                norm = a2 + a1*(1 - a2) + 1e-5

                rt = (a2*r + a1*rt*(1.0 - a2))/norm
                gt = (a2*g + a1*gt*(1.0 - a2))/norm
                bt = (a2*b + a1*bt*(1.0 - a2))/norm

            else:
                # "add": add the new image on top of the previous composition,
                # which is first damped where the new signal is strong
                a1 = s2p    # ~ 1-s2
                a2 = 1.0

                rt = a1*rt + a2*r
                gt = a1*gt + a2*g
                bt = a1*bt + a2*b

        rt = np.uint8(np.clip(rt, 0, 255))
        gt = np.uint8(np.clip(gt, 0, 255))
        bt = np.uint8(np.clip(bt, 0, 255))

        image_r = Image.frombytes("L", size, rt.tobytes())
        image_g = Image.frombytes("L", size, gt.tobytes())
        image_b = Image.frombytes("L", size, bt.tobytes())
        img = Image.merge('RGB', (image_r, image_g, image_b))

    # mode
    if opt.imode is not None:
        img = img.convert(opt.imode)

    #######################
    # write
    #######################

    if opt.outputfile:
        fout = opt.outputfile
    else:
        # without -d, frames are written to the current directory
        outdir = opt.dir if opt.dir is not None else os.curdir
        fout = os.path.join(outdir, '%08d.png' % (i))

    img.save(fout)

    # info
    print(file)

    nums = "%8d/%8d" % (i, n)
    print("%s    --> %s" % (nums, fout))
