#!/usr/bin/env python

#
# Configurable options
#

# Column indices into the scan file are zero-based:
# subtract 1 from the column number shown in the PNC file.
energy_column = 0
# Default I0 column; the script replaces it with the --i0column value.
i0_column = 3

# Threshold handed to mx.filter.HighFilter; None disables that filter.
high_filter = 5000

# Bad pixels as (x, y) pairs, handed to mx.filter.BadPixelFilter and
# interpolated over; an empty list disables that filter.  Usually not
# needed when high_filter is set appropriately.
bad_pixels = list()
# Example for the PNC Pilatus:
#bad_pixels = [ (14, 185) ]
# Or load the list from a file:
#bp = np.loadtxt('../bp_dp_pilatus_2.dat')
#bad_pixels = [(x,y) for x,y in bp]

# Columns passed as skip_columns to rixs.process (e.g. bad detector columns).
#skip_columns = [220, 221, 233, 234] #(DP-PILATUS2)
skip_columns = list()

#
# Everything below here should generally not need to be changed
#

import minixs as mx
import numpy as np
import sys, os

# One-line synopsis printed whenever the command-line arguments are missing or invalid.
usage = "Usage: %s <calibration> --scans <scan files> --exposures <exposure files> --outfile <output filename> --i0column <i0 column>" % sys.argv[0]


def parse_args():
  """Split sys.argv into calibration file, scans, exposures, output and I0 column.

  sys.argv[1] is the calibration file.  The remaining arguments are read left
  to right: a flag selects which list the following values are appended to,
  and values appearing before any flag are treated as scan files.

    -s / --scans       scan files
    -e / --exposures   exposure files
    -o / --outfile     output filename (only the first value is used)
    --i0column         I0 column in the scan file (only the first value is used)

  Returns:
    (calib_file, scans, exposures, outfile, i0col) -- scans and exposures are
    lists of strings; outfile and i0col are single strings.

  Exits the process with status 1, after printing the usage text, when no
  calibration file, output filename or I0 column was supplied.
  """
  if len(sys.argv) < 2:
    print ("Please specify a calibration file.")
    print (usage)
    sys.exit(1)
  calib_file = sys.argv[1]

  scans = []
  exposures = []
  outputs = []
  i0col = []
  # Each flag redirects subsequent values into its own list.
  targets = {
    '-s': scans, '--scans': scans,
    '-e': exposures, '--exposures': exposures,
    '-o': outputs, '--outfile': outputs,
    '--i0column': i0col,
  }
  cur = scans  # values before the first flag are scan files
  for arg in sys.argv[2:]:
    if arg in targets:
      cur = targets[arg]
    else:
      cur.append(arg)

  # sys.exit() must actually be called; a bare `exit` name is a no-op and
  # would fall through to an IndexError on the empty list below.
  if not outputs:
    print ("Please specify an output filename.")
    print (usage)
    sys.exit(1)
  if not i0col:
    print ("Please specify a column for I0 in the scan file.")
    print (usage)
    sys.exit(1)

  return calib_file, scans, exposures, outputs[0], i0col[0]

def progress_cb(i, energy):
  """Progress callback given to rixs.process: print one dot per call.

  Both arguments (presumably a step index and the current energy) are
  accepted only to match the callback signature and are ignored.  The
  stream is flushed each time so the dots appear as processing advances.
  """
  out = sys.stdout
  out.write(".")
  out.flush()

if __name__ == "__main__":
  # Script entry point: gather the scan/exposure inputs, configure an
  # mx.rixs.RIXS object, process it and save the result to the output file.
  # Every abort path exits with status 1 so callers can detect failure.
  import glob

  if len(sys.argv) < 4:
    print (usage)
    sys.exit(1)

  calibration_file, scans, exposures, outfile, i0colstr = parse_args()
  # The --i0column value always overrides the configured i0_column default.
  try:
    i0_column = int(i0colstr)
  except ValueError:
    print ("The I0 column must be an integer, got %r." % (i0colstr,))
    sys.exit(1)

  if len(scans) < 1 or len(exposures) < 1:
    print (usage)
    sys.exit(1)

  # Expand wildcard patterns here so the script also works where the shell
  # leaves them unexpanded (e.g. Windows cmd.exe).  Matches of each pattern
  # are sorted so the resulting file order is reproducible; patterns that
  # match nothing contribute no files.
  exposures = [f for pattern in exposures for f in sorted(glob.glob(pattern))]
  scans = [f for pattern in scans for f in sorted(glob.glob(pattern))]
  if not scans or not exposures:
    print ("No scan or exposure files matched the given names. Aborting.")
    sys.exit(1)

  # read in scan files, summing I0 over all of them
  energies = None
  I0s = None
  for scan in scans:
    e, i = mx.misc.read_scan_info(scan, [energy_column, i0_column])

    if len(e) == 0:
      print ("Scan %s contains no data points. Aborting." % (scan,))
      sys.exit(1)

    if energies is None:
      energies = e
      I0s = i
      continue

    # All scans must sample the same energies (to within 0.25) so that their
    # I0 values can be summed point by point.
    if len(e) != len(energies):
      print ("Scan %s has %d points but earlier scans have %d. Aborting."
             % (scan, len(e), len(energies)))
      sys.exit(1)
    if np.max(np.abs(energies - e)) > 0.25:
      print ("Energies from scan %s differ from others. Aborting." % scan)
      sys.exit(1)
    I0s = I0s + i

  # ensure that # of exposures is correct for number of entries in scan
  # (presumably each energy point has the same number of exposures)
  if len(exposures) % len(energies) != 0:
    print ("Number of energies does not divide number of exposures. Aborting.")
    sys.exit(1)

  calib = mx.calibrate.load(calibration_file)

  # setup bad pixel filter if needed; interpolate along the dispersive axis
  filters = []
  if bad_pixels:
    fltr = mx.filter.BadPixelFilter()
    if calib.dispersive_direction in [mx.UP, mx.DOWN]:
      mode = fltr.MODE_INTERP_V
    else:
      mode = fltr.MODE_INTERP_H
    fltr.set_val((mode, bad_pixels))
    filters.append(fltr)

  if high_filter is not None:
    fltr = mx.filter.HighFilter()
    fltr.set_val(high_filter)
    filters.append(fltr)

  rixs = mx.rixs.RIXS()
  rixs.exposure_files = [os.path.abspath(f) for f in exposures]
  rixs.calibration_file = os.path.abspath(calibration_file)
  rixs.energies = energies
  rixs.I0s = I0s
  rixs.filters = filters
  # Values are wrapped in 1-tuples so %-formatting never unpacks them.
  print ("rixs.exposure_files %s" % (rixs.exposure_files,))
  print ("rixs.calibration_file %s" % (rixs.calibration_file,))
  print ("rixs.energies %s" % (rixs.energies,))
  print ("rixs.I0s %s" % (rixs.I0s,))
  print ("rixs.filters %s" % (rixs.filters,))

  # Emission-energy grid in 0.25 steps from the smallest positive entry of the
  # calibration matrix up to its maximum (non-positive entries are excluded
  # from the lower bound, presumably uncalibrated pixels).
  c = calib.calibration_matrix
  emission_energies = np.arange(c[np.where(c > 0)].min(), c.max(), .25)

  print ("Energies %s" % (emission_energies,))
  print ("ExposureFiles %s " % (rixs.exposure_files,))

  # Written after the diagnostics so progress_cb's dots follow it directly.
  sys.stdout.write("Processing")
  sys.stdout.flush()
  rixs.process(emission_energies, progress_callback=progress_cb, skip_columns=skip_columns)
  print("")

  print("Saving...")
  rixs.save(outfile)
  print("Done")


