#!/usr/bin/env python
"""
ClinPhen commandline.
"""
import sys
import os
import logging
import time
import csv
import collections
from StringIO import StringIO
import multiprocessing
import argparse
import functools

import pandas as pd
from clinphen_src.src_dir import *
srcDir = get_src_dir()
from clinphen_src.get_phenotypes import *

# Directory that contains this script, with symlinks resolved by realpath
myDir = os.path.dirname(os.path.realpath(__file__))
# Location of the HPO term names mapping file, relative to the directory returned by get_src_dir()
hpo_main_names = os.path.join(srcDir, "data/hpo_term_names.txt")
# Fail at import time if the mapping file is missing.
# NOTE(review): assert statements are skipped when Python runs with -O, so this check is not guaranteed to execute.
assert os.path.isfile(hpo_main_names), "Cannot find HPO term names mapping file {}".format(hpo_main_names)

def getNames(fname=hpo_main_names):
  """
  Load a two-column, tab-delimited mapping file into a dict.

  The first column of each line becomes the key and the second column the value
  (by default the HPO term names file, so presumably HPO ID -> term name).
  Columns beyond the second are ignored; if a key repeats, the last line wins.
  Blank or whitespace-only lines are skipped.

  Raises IndexError if a non-blank line has fewer than two tab-separated fields.
  """
  returnMap = {}
  # Context manager guarantees the handle is closed even if a malformed line raises
  with open(fname) as source:
    for line in source:
      lineData = line.strip().split("\t")
      if lineData == [""]:  # Blank line (e.g. trailing newline at end of file): nothing to record
        continue
      returnMap[lineData[0]] = lineData[1]
  return returnMap

def yield_pat_notes_from_file(fname, patient_col, note_col, delimiter="|", header=True):
  """
  Given a file containing patient notes in a tabular format, lazily yield tuples of (patient, note).

  fname: path of the delimited text file to read
  patient_col / note_col: column holding the patient identifier / the note text.
    With header=True these are treated as column names (coerced to str);
    with header=False they are treated as zero-based positions (coerced to int).
  delimiter: field separator used in the file
  header: whether the first row of the file contains column names

  It is very important that this stays a generator instead of building the full list and
  returning it, so records are streamed one at a time. Progress is logged every 1000 records.

  Raises KeyError / IndexError if a row lacks the requested column, and ValueError if
  header=False and a column argument cannot be converted to int.
  """
  if header:
    patient_col = str(patient_col)
    note_col = str(note_col)
  else:
    patient_col = int(patient_col)
    note_col = int(note_col)

  # DictReader indexes rows by column name (header present); reader indexes rows by position
  reader_factory = csv.DictReader if header else csv.reader

  note_counter = 0
  t_start = time.time()
  with open(fname, mode='r') as source:
    reader = reader_factory(source, delimiter=delimiter, quotechar='"', skipinitialspace=True, quoting=csv.QUOTE_ALL)
    for item in reader:
      note_counter += 1
      if note_counter % 1000 == 0:
        t_delta = time.time() - t_start
        # A coarse clock can report zero elapsed time; avoid a ZeroDivisionError in that case
        rate = note_counter / t_delta if t_delta > 0 else float("inf")
        logging.info("{} total records consumed\t{} records per second".format(note_counter, rate))

      yield item[patient_col], item[note_col]
  # The generator ends by falling off the end of the function. An explicit
  # `raise StopIteration` here is turned into RuntimeError by PEP 479 (Python 3.7+).

def _extract_phenotypes_wrapper(pat_record_tuple, names, hpo_syn_file=HPO_SYN_MAP_FILE):
  """
  Record-level worker used for parallel processing.

  Takes a (patient, record) pair, runs extract_phenotypes on the record, parses the
  tab-delimited string it returns into a dataframe, and tags every row with the patient.
  Returns a (patient, dataframe) tuple.
  """
  patient_id, note_text = pat_record_tuple
  raw_table = extract_phenotypes(note_text, names, hpo_syn_file)
  # extract_phenotypes hands back text, so read it through an in-memory buffer
  table = pd.read_table(StringIO(raw_table))
  # Patient identifier goes in as the first column
  table.insert(0, '#Patient ID', patient_id)
  # The example sentence is removed so we don't leak PHI
  table = table.drop(columns=['Example sentence'])
  return patient_id, table

def _df_combine_and_sort(dataframes):
  """
  Given a list of dataframes, combine and sort them
  """
  assert dataframes  # Cannot be an empty list
  combined = pd.concat(dataframes)
  # Merge together multiple observations of the same term
  combined = combined.groupby(["#Patient ID", "HPO ID", 'Phenotype name']).agg({
    'No. occurrences': 'sum',
    'Earliness (lower = earlier)': 'min',
  }, axis=0).reset_index()  # Combine rows
  # Num occurrences is sorted in descending order, earliness in ascending order
  combined.sort_values(by=['No. occurrences', 'Earliness (lower = earlier)'], ascending=[False, True], inplace=True)
  return combined

def build_parser():
  """Construct the argparse parser describing this script's commandline interface"""
  parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  add = parser.add_argument

  # Positional arguments: input and output paths
  add("recordfile", type=str, help="Tabular file containing records to parse")
  add("outfile", type=str, help="Output tsv file")

  # How to read the input table
  add("--delim", type=str, default="|", help="Delimiter in recordfile")
  add(
    "--patcol", type=str, default='MRN',
    help="Column for patient identifiers. Should be integer if recordfile has no header",
  )
  add(
    "--notecol", type=str, default='NOTE',
    help="Column for notes. Should be integer if recordfile has no header",
  )
  add("--noheader", action="store_true", help="Enable flag if recordfile has no headers")

  # Extraction and parallelism settings
  add("--thesaurus", type=str, required=False, default=HPO_SYN_MAP_FILE, help="Custom thesaurus to use")
  add("--threads", type=int, default=multiprocessing.cpu_count(), help='Number of threads to run')
  return parser

def main():
  """
  Runs the script.

  Parses the commandline, streams (patient, note) records out of the record file, extracts
  phenotypes from each note in a multiprocessing pool, merges the per-note tables into one
  sorted table per patient (patients in sorted order), and writes the combined table to the
  output file as tab-separated text.
  """
  # Parse commandline arguments
  parser = build_parser()
  args = parser.parse_args()
  # Floor division keeps the limit an integer; true division yields a float on Python 3,
  # which csv.field_size_limit rejects
  csv.field_size_limit(sys.maxsize // args.threads)

  hpo_to_name = getNames()

  # Create a version of extract_phenotypes that has some arguments fixed
  # This is necessary because parallel map can only take in single-arg functions
  pfunc = functools.partial(_extract_phenotypes_wrapper, names=hpo_to_name, hpo_syn_file=args.thesaurus)

  # Lazy stream of records; --delim is forwarded so a non-default delimiter is actually honoured
  notes = yield_pat_notes_from_file(
    args.recordfile,
    args.patcol,
    args.notecol,
    delimiter=args.delim,
    header=not args.noheader,
  )

  # https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
  logging.info("Processing notes with {} threads".format(args.threads))
  pool = multiprocessing.Pool(args.threads)
  try:
    patient_aggregator = collections.defaultdict(list)
    for patient, df in pool.imap_unordered(pfunc, notes):
      if not df.empty:  # Only keep notes that produced at least one phenotype
        patient_aggregator[patient].append(df)

    # Aggregate each patient's subtables into a single per-patient table, in sorted patient order
    per_patient_dfs = pool.map(
      _df_combine_and_sort,
      [patient_aggregator[k] for k in sorted(patient_aggregator)],
    )
  except BaseException:
    # Do not leave worker processes behind when processing fails or is interrupted
    pool.terminate()
    raise
  else:
    pool.close()
  finally:
    pool.join()

  if per_patient_dfs:
    # Combine all the per-patient tables into one large table
    total_table = pd.concat(per_patient_dfs)
  else:
    # pd.concat raises on an empty list; emit a header-only table instead of crashing
    logging.warning("No phenotypes were extracted from any record; writing an empty table")
    total_table = pd.DataFrame(columns=[
      '#Patient ID',
      'HPO ID',
      'Phenotype name',
      'No. occurrences',
      'Earliness (lower = earlier)',
    ])

  # Write output
  total_table.to_csv(args.outfile, sep='\t', index=False)

if __name__ == "__main__":
  # Set up timestamped INFO-level logging before any work starts
  log_settings = dict(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
  )
  logging.basicConfig(**log_settings)
  main()
