#!/usr/bin/env python3

import argparse
from astropy.io import fits
import os
import numpy as np
from pNbody.Mockimgs import noise
  
####################################################################
# option parser
####################################################################

description="add Gaussian noise to a surface brightness (SB) image"
epilog     ="""
Examples:
--------
mockimgs_sb_add_noise in.fits --SB_limit 31 -o out.fits
mockimgs_sb_add_noise in.fits --SB_limit 31 --output-flux -o out.fits

"""

# Command-line interface.  Either --SB_limit or --sigma_per_pixel must be
# given; when both are present, --SB_limit wins.
parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


parser.add_argument("file",
                    metavar='FILE',
                    type=str,
                    help='input FITS file; pixels are read as SB magnitudes if its UNITS keyword '
                         'is "mag/arcsec^2", and as fluxes otherwise')

parser.add_argument("--SB_limit",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='surface brightness limit from which the per-pixel noise is derived '
                         '(takes precedence over --sigma_per_pixel)')

parser.add_argument("--SB_area",
                    metavar='FLOAT',
                    type=float,
                    default=100,
                    help='area in arcsec^2 over which the SB limit is defined (default: %(default)s)')

parser.add_argument("--SB_SN",
                    metavar='FLOAT',
                    type=float,
                    default=3,
                    help='signal-to-noise ratio defining the SB limit (default: %(default)s)')

parser.add_argument("--sigma_per_pixel",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='standard deviation of the noise per pixel, in pixel flux units '
                         '(used only when --SB_limit is not given)')

parser.add_argument("--random-seed",
                    dest="random_seed",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='seed of the random noise generator (default: %(default)s)')

parser.add_argument("--output-flux",
                    action="store_true",
                    dest="output_flux",
                    default=False,
                    help='store the noisy flux instead of converting it back to surface brightness')

parser.add_argument("--noise-only",
                    action="store_true",
                    dest="noise_only",
                    default=False,
                    help='output only the noise map, discarding the input signal')

parser.add_argument("-v",
                    action="store_true",
                    dest="verbose",
                    default=False,
                    help='verbose mode')

parser.add_argument("--zp",
                    metavar='FLOAT',
                    type=float,
                    default=0,
                    help='zero point used to convert between surface brightness and flux '
                         '(default: %(default)s)')

parser.add_argument("-o",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="name of the output file; the image is saved gzip-compressed as NAME.gz. "
                         "If omitted, nothing is written")



                                        
                    
####################################################################
# main
####################################################################

compress=True

if __name__ == '__main__':
  
  opt = parser.parse_args()
    
  filename = opt.file
    
  if opt.verbose:
    print(filename)
  
  hdul = fits.open(filename)
  
  data = hdul[0].data          # get the surface brightness map
  header = hdul[0].header 
  
  # get the pixel size
  if "PIXFOVX" in header:
    PIXFOVX = header["PIXFOVX"]
  else:
    raise ValueError("'PIXFOVX' not found in the fits header.")
  
  # get the image shape
  if "NAXIS1" in header:
    NAXIS1 = header["NAXIS1"]
  else:
    raise ValueError("'NAXIS1' not found in the fits header.")    
  
  if "NAXIS2" in header:
    NAXIS2 = header["NAXIS2"]
  else:
    raise ValueError("'NAXIS2' not found in the fits header.")    
    
  if "PIXAREA" in header:
    Omega = header["PIXAREA"]
  else:
    raise ValueError("'PIXAREA' not found in the fits header.")        
    


  # check if pixels contains magnitudes of fluxes
  # get the image shape
  if "UNITS" in header:
    units = header["UNITS"]
  else:
    raise ValueError("'UNITS' not found in the fits header.")    
  
  # the image contains magnitudes
  if units=='mag/arcsec^2':
    # get the flux
    F = Omega*10**(-(data-opt.zp)/2.5)
  
  # the image contains a flux
  else:
    F = data


  if opt.SB_limit is not None:
    # get the std to apply on a pixel basis    
    sigma = noise.get_std_for_SB_limit(opt.SB_limit,PIXFOVX,area=opt.SB_area,sn=opt.SB_SN)
  
  elif opt.sigma_per_pixel is not None:
    sigma = opt.sigma_per_pixel
  
  else:
    raise(ValueError,"SB_limit or sigma_per_pixel must be provided.")
  
  
  if opt.verbose:
    print("sigma per pixel: ",sigma)
    
  
  # generate a noise map
  np.random.seed(opt.random_seed)
  noise_map = np.random.normal(loc=0.0, scale=sigma, size=(NAXIS1,NAXIS2))
  
  # add to the flux
  if opt.noise_only:
    F = noise_map
  else:
    F = F + noise_map
  
  if not opt.output_flux:
    # clean negative values
    F = np.where(F<0, 1e-40, F)
    # back to SB magnitudes
    data = -2.5*np.log10(F/Omega) + opt.zp    
  else:
    # wee keep the flux
    data = F
          
  # fits stuffs
  hdu = fits.PrimaryHDU(data)
  hdu.header = header
  
  # add keywords
  script_name = os.path.basename(__file__)
  hdu.header['NOISESTD']  = (sigma,        "%s: noise std"%script_name)
  hdu.header['SB_LIMIT']  = (opt.SB_limit, "%s: computed SB limit"%script_name)
  hdu.header['SB_AREA']   = (opt.SB_area,  "%s: area considered"%script_name)
  hdu.header['SB_SN']     = (opt.SB_SN,    "%s: SN considered"%script_name)
  
  
  if opt.output_flux: 
    hdu.header['UNITS']     = ('undefined flux','pixel unit')
    hdu.header['ZP']        = (opt.zp,'zero point used')
  
  
  
  # save
  if opt.outputfilename:

    if os.path.isfile(opt.outputfilename):
      os.remove(opt.outputfilename)
  
    hdu.writeto(opt.outputfilename)
    
    if compress:
      import gzip
      import shutil
  
      with open(opt.outputfilename, 'rb') as f_in:
        with gzip.open("%s.gz"%opt.outputfilename, 'wb') as f_out:
          shutil.copyfileobj(f_in, f_out)
  
      os.remove(opt.outputfilename)
  
    
    



