#!/usr/bin/env python
#
# Generate an image from the RESTful API of a TART radio telescope
# Tim Molteno 2017 tim@elec.ac.nz
#
import matplotlib
import os
if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

import logging
import argparse
import sys
import threading
import datetime
import json

import numpy as np

from tart_tools import api_imaging
from tart_tools import api_handler
from tart.operation import settings
from tart.imaging import elaz

from copy import deepcopy

# Module-wide logger. getLogger() with no name returns the root logger, so a
# handler attached to it also receives records that named loggers propagate
# upwards (propagation is on by default).
logger = logging.getLogger(name=None)

def handle_image(args, img, n_bin, title, time_repr, source_json=None):
    """Write and/or display one image according to the output flags in *args*.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Reads ``fits``, ``PNG``, ``display``,
        ``dir`` and ``healpix``.
    img : array
        The image to output.
    n_bin : int
        Passed through to ``api_imaging.make_image``.
    title : str
        Image title; combined with *time_repr* to form the output file stem.
    time_repr : str
        Timestamp string appended to the title (also used as the FITS timestamp).
    source_json : optional
        Passed through to ``api_imaging.make_image``. Presumably the catalogue
        of known sources to overlay (``--show-sources``); confirm.

    Files are named ``<title>_<time_repr>.fits`` / ``.png``. PNGs are saved
    into ``args.dir``; the FITS writer receives ``args.dir`` as ``out_dir``.
    """
    image_title = '{}_{}'.format(title, time_repr)
    if args.fits:
        fname = '{}.fits'.format(image_title)
        api_imaging.save_fits_image(img, fname=fname, out_dir=args.dir, timestamp=time_repr)
        print("Generating {}".format(fname))
    if args.PNG or args.display:
        api_imaging.make_image(plt, img, image_title, n_bin, source_json, args.healpix)
        if args.PNG:
            fname = '{}.png'.format(image_title)
            plt.savefig(os.path.join(args.dir, fname))
            print("Generating {}".format(fname))
        if args.display:
            plt.show()
        # pyplot keeps figures alive until explicitly closed. This function is
        # called once per requested image, so release what make_image drew.
        # Otherwise figures accumulate, or the next image is drawn onto this one.
        plt.close('all')

if __name__ == '__main__':
    PARSER = argparse.ArgumentParser(description='Generate an image using the web api ofs a TART radio telescope.', 
                                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    PARSER.add_argument('--api', required=False, default='https://tart.elec.ac.nz/signal', help="Telescope API server URL.")
    PARSER.add_argument('--catalog', required=False, default='https://tart.elec.ac.nz/catalog', help="Catalog API URL.")
    PARSER.add_argument('--file', required=False, default=None, help="Get data from a calibration JSON file.")

    PARSER.add_argument('--gains', required=False, default=None, help="Use a local JSON file containing antenna gains to create the image image.")
    PARSER.add_argument('--vis', required=False, default=None, help="Use a local JSON file containing the visibilities to create the image image.")
    PARSER.add_argument('--dir', required=False, default='.', help="Output directory.")
    PARSER.add_argument('--rotation', type=float, default=0.0, help="Apply rotation (in degrees) to the antenna positions.")
    PARSER.add_argument('--nfft', type=int, default=10, help="Log(2) of the number of points in the fft.")

    PARSER.add_argument('--dirty', action="store_true", help="Create a direct IFFT dirty image.")
    PARSER.add_argument('--difmap', action="store_true", help="Use difmap to generate a CLEAN image (requires extenal difmap executable).")
    PARSER.add_argument('--aipy', action="store_true", help="Use AIPY to generate a CLEAN image.")
    PARSER.add_argument('--moresane', action="store_true", help="Use MORESANE to generate a wavelet based compressed representation (requires extenal difmap executable).")
    PARSER.add_argument('--beam', action="store_true", help="Generate beam image")

    PARSER.add_argument('--healpix', action="store_true", help="Use HealPIX to map the resulting image onto a projected circle.")

    PARSER.add_argument('--display', action="store_true", help="Display Image to the user")
    PARSER.add_argument('--log', action="store_true", help="Output the log of the image")
    PARSER.add_argument('--fits', action="store_true", help="Generate a FITS format image")
    PARSER.add_argument('--PNG', action="store_true", help="Generate a PNG format image")
    PARSER.add_argument('--show-sources', action="store_true", help="Show known sources on images (only works on PNG).")

    source_json = None

    ARGS = PARSER.parse_args()

    logger.setLevel(logging.DEBUG)
    # create console handler and set level to debug
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # add formatter to ch
    ch.setFormatter(formatter)
    # add ch to logger
    logger.addHandler(ch)

    if ARGS.file:
        logger.info("Getting Data from file: {}".format(ARGS.file))
        # Load data from a JSON file
        with open(ARGS.file, 'r') as json_file:
            calib_info = json.load(json_file)

        info = calib_info['info']
        ant_pos = calib_info['ant_pos']
        config = settings.from_api_json(info['info'], ant_pos)

        flag_list = [] # [4, 5, 14, 22]

        original_positions = deepcopy(config.get_antenna_positions())

        gains_json = calib_info['gains']
        gains = np.asarray(gains_json['gain'])
        phase_offsets = np.asarray(gains_json['phase_offset'])
        config = settings.from_api_json(info['info'], ant_pos)
    
        measurements = []
        for d in calib_info['data']:
            vis_json, source_json = d
            cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains, phase_offsets, flag_list)
            src_list = elaz.from_json(source_json, 0.0)

    else:
        logger.info("Getting Data from API: {}".format(ARGS.api))

        api = api_handler.APIhandler(ARGS.api)
        config = api_handler.get_config(api)

        if (ARGS.gains is None):
            gains = api.get('calibration/gain')
        else:
            with open(ARGS.gains, 'r') as json_file:
                gains = json.load(json_file)

        if (ARGS.vis is None):
            vis_json = api.get('imaging/vis')
        else:
            with open(ARGS.vis, 'r') as json_file:
                vis_json = json.load(json_file)

        ts = api_imaging.vis_json_timestamp(vis_json)
        if ARGS.show_sources:
            source_json = api.get_url(api.catalog_url(config, datestr=ts.isoformat()))

        logger.info("Data Download Complete")

        cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains['gain'], gains['phase_offset'], flag_list=[])


    api_imaging.rotate_vis(ARGS.rotation, cv, reference_positions = deepcopy(config.get_antenna_positions()))
    n_bin = 2**ARGS.nfft

    time_repr = "{:%Y_%m_%d_%H_%M_%S_%Z}".format(timestamp)

    # Image informstion
    # prihdr.set('DATE', timestamp)
    # prihdr.set('DATE-OBS', timestamp)
    # 
    # prihdr.set('TIMESYS', 'UTC')
    # prihdr.set('INSTRUME', 'TART')
    # prihdr.set('TELESCOP', 'TART')
    # prihdr.set('OBSERVER' 'CASA simulator')
    # prihdr.set('ORIGIN', 'tart_tools tart.elec.ac.nz ')
    # 
    # prihdr.set('OBSRA', 2.889721000000E+02)
    # prihdr.set('OBSDEC', -7.466052777778E+01)
    # prihdr.set('OBSGEO-X', 5.111202828133E+06)
    # prihdr.set('OBSGEO-Y', 2.001309252764E+06)
    # prihdr.set('OBSGEO-Z', -3.237339358474E+06)


    # Processing

    if ARGS.dirty or ARGS.moresane or ARGS.aipy:
        cal_ift, cal_extent, n_fft, bin_width = api_imaging.image_from_calibrated_vis(cv, nw=n_bin/4, num_bin=n_bin)

        img = np.abs(cal_ift)
        sd = np.std(img)
        ift_scaled = img/sd

    if ARGS.beam or ARGS.moresane or ARGS.aipy:
        beam = np.abs(api_imaging.beam_from_calibrated_vis(cv, nw=n_bin/4, num_bin=n_bin))

    if ARGS.difmap:
        fits_bin = 2**12
        
        fitsfile =  "{}.uvfits".format(time_repr)
        uvfitspath = os.path.join(ARGS.dir, fitsfile)
        fits_gen = api_imaging.get_uv_fits(cv)
        fits_gen.write(uvfitspath)

        commands = """
! Basic imaging instructions by Tim Molteno, based around the final_clean script
! by Dan Homan,  dhoman@nrao.edu
debug = False
observe {}
mapsize {}, {}
docont=false
mapcolor color
integer clean_niter; 
float clean_gain; clean_gain = 0.03
float dynam;
float flux_peak;

! Define the inner loop as a macro.

float flux_cutoff
float dyn_range
float last_in_rms
float in_rms
float target_rms
float V_rms

#+map_residual \
flux_peak = peak(flux);\
flux_cutoff = imstat(rms) * dynam;\
while(abs(flux_peak)>flux_cutoff);\
 clean clean_niter,clean_gain;\
 flux_cutoff = imstat(rms) * dynam;\
 flux_peak = peak(flux);\
end while

! The following macro stops
! when the the in_frame RMS matches the
! the V RMS OR if there is not improvement 
! in the in_frame RMS

#+deep_map_residual \
in_rms = imstat(rms);\
print "Target RMS: ", target_rms, "  In Frame RMS: ", in_rms;\
while(in_rms > target_rms);\
 clean min(200*(in_rms/target_rms),1000),clean_gain;\
 last_in_rms = in_rms;\
 in_rms = imstat(rms);\
 print "Target RMS: ", target_rms, "  In Frame RMS: ", in_rms;\
 if(last_in_rms <= in_rms);\
  in_rms = target_rms;\
 end if;\
 selfcal;\
end while


! select V and get the V_rms for comparison
select v
clrmod true,true,true
delwin
uvw 0,-2
uvtaper 0
V_rms = imstat(rms)

! select the stokes to clean
select i

! clear previous model
clrmod true,true,true

! delete any windows
delwin

! remove any tapering
uvtaper 0

print "*********** FIRST TRY SUPER-UNIFORM WEIGHTING **********"
print "**** -- only if dynamic range is higher than 10 -- *****"

dynam = 10
clean_niter = 50
uvw 20,-1
map_residual
uvw 10,-1
map_residual
clean_niter = 50

print "*********** REGULAR UNIFORM WEIGHTING NEXT ***************"

uvw 2,-1
dynam = 6
map_residual
print "********** DEEP CLEANING AT NATURAL WEIGHTING **************"
uvw 0,-2
! now let clean go deep
target_rms = imstat(noise)/8
if(target_rms < imstat(rms))
  deep_map_residual
else
  ! clean 1 component just to have something to restore
  clean 1, clean_gain
end if

in_rms = imstat(rms)
print "********** FINAL CLEAN IS FINISHED **************"
print "Target RMS was: ", target_rms, " Reached RMS: ", in_rms
print "For comparison uncleaned V RMS is: ", V_rms
print "*************************************************"

device {}_difmap.png/png
mappl cln
!wmap {}_clean.fits
exit
""".format(uvfitspath, fits_bin, 1.3*648000000.0/fits_bin, uvfitspath, uvfitspath)
        f = open('difmap_cmds', 'w')
        f.write(commands)
        f.close()
        os.system("/usr/local/bin/difmap < difmap_cmds")
        #os.system("vips scale {}_clean.fits {}_grey.png".format(uvfitsfile, uvfitsfile))

    if ARGS.moresane:
        from pymoresane.main import DataImage

        degrees_per_pixel = 180.0 / len(beam)

        sane = DataImage(img/np.sum(beam), beam, mask_data=None, cdelt1=degrees_per_pixel, cdelt2=degrees_per_pixel)
        logger = sane.make_logger('INFO')
        
        beam_width_degrees = 3.0
        stop_scale = ARGS.nfft - 1  # Maximum size feature (full image width)
        start_scale = int(np.log2(n_bin*beam_width_degrees/180.0))
        
        print("Start Scale {}".format(start_scale))
        sane.moresane(subregion=None, sigma_level=4.0, loop_gain=0.05, tolerance=0.75, accuracy=1e-6,
                major_loop_miter=100, minor_loop_miter=50, all_on_gpu=False, decom_mode="ser", core_count=1,
                conv_device='cpu', conv_mode='circular', extraction_mode='cpu', enforce_positivity=True,
                edge_suppression=True, edge_offset=int(n_bin/30), flux_threshold=0,
                neg_comp=False, edge_excl=0, int_excl=0)
        sane.restore()

    if ARGS.aipy:
        import aipy
        clean, info = aipy.deconv.clean(img, beam, gain=.1, maxiter=1000, tol=1e-5, verbose=True, stop_if_div=True, pos_def=False)

        if True:
            b_restore = np.fft.fftshift(beam)
        else:
            #b_restore = a.img.gaussian_beam(3, shape=img.shape) #* np.max(beam)
            b_center = len(beam)/2
            b_width_px = int(3.0 * n_bin/180)
            b_restore_shifted = np.zeros_like(beam)

            y,x = np.ogrid[-b_center:b_center, -b_center:b_center]
            mask = x*x + y*y <= b_width_px**2

            b_restore_shifted[mask] = np.fft.fftshift(beam)[mask]
            b_restore = np.fft.fftshift(b_restore_shifted)

        restored_image = np.abs(np.fft.ifft2(np.fft.fft2(clean) * np.fft.fft2(b_restore))) # Cropped PSF Restoration

        print('Clean', info['success'], np.min(restored_image), np.max(restored_image))


    # Do output images

    if ARGS.moresane:
        handle_image(ARGS, sane.restored, n_bin, "MORESANE", time_repr, source_json)

    if ARGS.beam:
        if ARGS.log:
            beam = np.log10(beam)
        handle_image(ARGS, beam, n_bin, "beam", time_repr)

    if ARGS.dirty:
        handle_image(ARGS, ift_scaled, n_bin, "dirty", time_repr, source_json)

    if ARGS.aipy:
        handle_image(ARGS, restored_image, n_bin, "clean", time_repr, source_json)

