#!/usr/bin/python3

import numpy as np
import argparse
from astropy.io import fits

import os
import matplotlib.pyplot as plt
  
####################################################################
# option parser
####################################################################

description = "create a png composite rgb image from SDSS g, r, i, z bands"
epilog = """
Examples:
--------
mockimgs_griz_to_png imgSDSSg.fits.gz imgSDSSr.fits.gz imgSDSSi.fits.gz imgSDSSz.fits.gz
mockimgs_griz_to_png imgSDSSg.fits.gz imgSDSSr.fits.gz imgSDSSi.fits.gz imgSDSSz.fits.gz -o image.png
"""

# Command-line interface: exactly four FITS files plus an optional output
# name.  RawDescriptionHelpFormatter keeps the epilog's line breaks verbatim
# in the --help text.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# Positional: the four band images, collected into opt.files in the order
# given on the command line (expected g, r, i, z).
parser.add_argument(
    "files",
    metavar="FILES",
    nargs=4,
    type=str,
    default=None,
    help='SDSS g r i and z fits images',
)

# Optional: PNG destination; when omitted the script displays the figure.
parser.add_argument(
    "-o",
    dest="outputfilename",
    type=str,
    default=None,
    help="Name of the output file",
)




def get_rgb(imgs, bands, allbands=('g', 'r', 'i', 'z'), resids=False,
            mnmx=None, arcsinh=None, scales=None):
    '''
    Return a scaled RGB composite built from per-band images.

    *imgs*     a list of numpy arrays of the same size, in nanomaggies
    *bands*    a list of band names matching *imgs*, e.g. ['g','r','z']
    *allbands* accepted for backward compatibility; not used
    *resids*   if true, force a linear stretch with mnmx = (-0.1, 0.1);
               this overrides any *mnmx* passed in
    *mnmx*     (min, max) values that become black/white *after* per-band
               scaling.  When given, a linear stretch is used (no offset,
               no arcsinh).  When None (and *resids* is false) the default
               arcsinh stretch of sdss_rgb is used instead.
    *arcsinh*  accepted for backward compatibility; ignored
    *scales*   optional dict band -> (plane, scale) forwarded to sdss_rgb
               to override its default per-band scales

    Returns a (H,W,3) numpy array (values clipped to [0, 1] by sdss_rgb's
    default clip=True).
    '''
    # A tuple default avoids the shared-mutable-default pitfall; *allbands*
    # and *arcsinh* are otherwise unused and kept only for existing callers.
    if resids:
        mnmx = (-0.1, 0.1)
    if mnmx is not None:
        # Linear stretch: zero offset and no nonlinear (arcsinh) mapping.
        return sdss_rgb(imgs, bands, scales=scales, m=0., Q=None, mnmx=mnmx)
    return sdss_rgb(imgs, bands, scales=scales)


def sdss_rgb(imgs, bands, scales=None, m=0.03, Q=20, mnmx=None, clip=True):
    '''
    Blend per-band images into an RGB composite.

    *imgs*   sequence of 2-D numpy arrays of identical shape (nanomaggies)
    *bands*  sequence of band names, one per image, from 'g','r','i','z';
             a list, tuple or numpy array of names is accepted
    *scales* optional dict band -> (plane, scale) overriding the defaults
    *m*      offset added to each band after multiplying it by its scale
    *Q*      arcsinh softening parameter.  With *mnmx* None every scaled band
             is multiplied by arcsinh(Q*I) / (sqrt(Q)*I), where I is the mean
             of the non-negative scaled bands.  With Q=None the factor is I.
    *mnmx*   (min, max): when given, each scaled band is instead mapped
             linearly so that min -> 0 and max -> 1
    *clip*   clip every band contribution to [0, 1]

    When *bands* is exactly g, r, i, z (in that order) each band is spread
    over the R, G, B planes with fixed weights (each plane's weights sum to
    1).  For any other band list each band is written to the single plane
    given in the scales table; bands sharing a plane overwrite one another,
    the last one wins.

    Returns a (H,W,3) float32 array.
    Raises ValueError if *bands* is empty or differs in length from *imgs*.
    '''
    rgb_stretch_factor = 1.0
    # band -> (RGB plane used by the generic branch, multiplicative scale)
    rgbscales = dict(
        g=(2, 6.0 * rgb_stretch_factor),
        r=(1, 3.4 * rgb_stretch_factor),
        i=(0, 3.0 * rgb_stretch_factor),
        z=(0, 2.2 * rgb_stretch_factor),
    )
    if scales is not None:
        rgbscales.update(scales)

    # Normalise so the g,r,i,z test below also matches tuples/arrays
    # (a tuple never compares equal to a list, and an ndarray comparison
    # is ambiguous in a boolean context).
    bands = list(bands)
    if not bands:
        raise ValueError("sdss_rgb: at least one band is required")
    if len(imgs) != len(bands):
        raise ValueError(
            "sdss_rgb: got %d images but %d bands" % (len(imgs), len(bands)))

    # Scaled flux per band, computed once and reused below.
    scaled = [img * rgbscales[band][1] + m for img, band in zip(imgs, bands)]

    # Mean non-negative intensity across bands.
    I = sum(np.maximum(0, s) for s in scaled) / len(bands)
    if Q is not None:
        fI = np.arcsinh(Q * I) / np.sqrt(Q)
        # Nudge exact zeros so the ratio below is finite.
        I += (I == 0.) * 1e-6
        I = fI / I

    H, W = I.shape
    rgb = np.zeros((H, W, 3), np.float32)

    def _stretch(s):
        # Map one scaled band to display values (intensity-weighted or linear).
        if mnmx is None:
            v = s * I
        else:
            mn, mx = mnmx
            v = (s - mn) / (mx - mn)
        return np.clip(v, 0, 1) if clip else v

    if bands == ['g', 'r', 'i', 'z']:
        # (R, G, B) weights contributed by each band.
        rgbvec = dict(
            g=(0., 0., 0.75),
            r=(0., 0.5, 0.25),
            i=(0.25, 0.5, 0.),
            z=(0.75, 0., 0.))
        for s, band in zip(scaled, bands):
            v = _stretch(s)
            for plane, frac in enumerate(rgbvec[band]):
                if frac != 0.:
                    rgb[:, :, plane] += frac * v
    else:
        for s, band in zip(scaled, bands):
            plane = rgbscales[band][0]
            rgb[:, :, plane] = _stretch(s)
    return rgb
                                
                                        
                    
####################################################################
# main
####################################################################


if __name__ == '__main__':

    opt = parser.parse_args()

    # Read the four images in command-line order (g, r, i, z).  Pixel values
    # are treated as magnitudes and converted to nanomaggies, which is what
    # get_rgb expects: flux = 10**(-0.4 * (mag - 22.5)).
    # (The original read the headers too, but never used them.)
    band_fluxes = []
    for filename in opt.files:
        mag = fits.getdata(filename)
        band_fluxes.append(10**(-0.4 * (mag - 22.5)))

    color_img = get_rgb(band_fluxes, ['g', 'r', 'i', 'z'])

    # Plot styling; the subplot margins of 0..1 make the axes span the
    # whole 12x12 figure.  Applied before imshow creates the figure.
    params = {
        "axes.labelsize": 14,
        "axes.titlesize": 18,
        "font.size": 12,
        "legend.fontsize": 12,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "text.usetex": False,
        "figure.subplot.left": 0.0,
        "figure.subplot.right": 1.0,
        "figure.subplot.bottom": 0.0,
        "figure.subplot.top": 1.0,
        "figure.subplot.wspace": 0.02,
        "figure.subplot.hspace": 0.02,
        "figure.figsize": (12, 12),
        "lines.markersize": 6,
        "lines.linewidth": 3.0,
    }
    plt.rcParams.update(params)

    plt.imshow(color_img)

    # Write to file when -o is given, otherwise open an interactive window.
    if opt.outputfilename:
        plt.savefig(opt.outputfilename)
    else:
        plt.show()




