#!/usr/bin/env python

'''
Functions to cluster-match across multiple values of K.

For now, it takes in 2 files. They must have adjacent K values.

author: Aaron Behr
created: 2014-06-29
'''
import sys
import numpy as np
import os
from os import path
import argparse
import time
import json
from shutil import rmtree
import tornado.ioloop
import tornado.web
import tornado.websocket
# from cairosvg import svg2pdf

import pong
from pong import parse, cm, write, align, distruct



# Module-level state. `pongdata` and `run_pong_args` start out empty and are
# filled in by main(); the websocket handler reads them to run the pipeline.
clients, threads = [], []
pongdata = run_pong_args = None

class Pongdata:
	'''
	Container holding references to all of the main data of one pong run.

	intro     -- banner text stored on the object
	outputdir -- stored as `output_dir`
	printall  -- stored as `print_all`
	'''
	def __init__(self, intro, outputdir, printall):
		# values handed in by the caller
		self.intro = intro
		self.output_dir = outputdir
		self.print_all = printall

		# run bookkeeping
		self.runs = {} # contains all Run objects
		self.name2id = {} # run name to run ID
		self.all_kgroups = [] # contains kgroups in order
		self.cluster_matches = {} # all clustering solutions matching 2 runs

		# -1 marks "not yet known"
		self.num_indiv = self.K_min = self.K_max = -1

		# population / ordering information, unset until provided
		for unset_attr in ('ind2pop', 'pop_order', 'popcode2popname',
				'popindex2popname', 'pop_sizes', 'sort_by', 'indiv_avg'):
			setattr(self, unset_attr, None)

		self.colors = [] # use custom colors?

		self.status = 0 # incomplete, working, or complete (0,1,2)


# Banner shown at startup and written at the top of params_used.txt.
intro = ''.join([
	'-------------------------------------------------------------------\n',
	'                            p o n g\n',
	'          by A. Behr, G. Liu-Fang, K. Liu, and S. Ramachandran\n',
	'                       Version 0.4 (2015)\n',
	'-------------------------------------------------------------------\n',
	'-------------------------------------------------------------------\n',
])






def _confirm(prompt, retry_prompt):
	'''
	Ask a yes/no question on the command line.

	`prompt` is shown first; `retry_prompt` is shown repeatedly until the
	answer is one of y/Y/n/N. Returns True for yes and False for no. A closed
	stdin (EOFError) is treated as "no" so that non-interactive runs exit
	cleanly instead of dying with a traceback.
	'''
	try:
		answer = raw_input(prompt)
		while answer not in ('y', 'Y', 'n', 'N'):
			answer = raw_input(retry_prompt)
	except EOFError:
		sys.stdout.write('\n')
		return False
	return answer in ('y', 'Y')


def main():
	'''
	Command-line entry point.

	Parses the command-line options, validates the input paths (filemap,
	optional ind2pop / pop names / color files), prepares the output
	directory, creates the module-level `pongdata` object and writes
	`params_used.txt`. Then either runs the pipeline right away
	(`--disable_server`) or starts the tornado server and blocks in its IO
	loop until interrupted.

	Exits via sys.exit() on invalid input or when the user declines a prompt.
	'''
	global pongdata, run_pong_args

	dist_metrics = ['sum_squared', 'percent', 'G', 'jaccard']

	parser = argparse.ArgumentParser(description='pong, A. Behr, G. Fang, '
		'K. Liu, and S. Ramachandran')

	parser.add_argument('-m', '--filemap', required=True,
		help='path to params file containing information about input '
		'Q-matrix files')
	parser.add_argument('-c', '--ignore_cols', type=int, default = 0,
		help='ignore the first i columns of every data line. Typically 5 for '
		'Structure output and 0 for ADMIXTURE output. Default = 0')
	parser.add_argument('-o', '--output_dir',
		default = './pong_output', help='specify output dir for files to be '
		'written to. By default, pong makes a folder named "pong_output" in '
		'the current working directory.')

	parser.add_argument('-i','--ind2pop', default=None,
		help='ind2pop data (can be either a Q-matrix column number or the '
		'path to a file containing the data).')
	parser.add_argument('-n', '--pop_names', default=None,
		help='Path to file containing population order/names.')
	parser.add_argument('-l','--color_list',
		help='List of colors to be used for visualization. If this file is not '
		'included, then default colors will be used for visualization.')

	parser.add_argument('-f', '--force', default=False,
		action='store_true', help='force overwrite already existing output '
		'directory. By default, pong will prompt the user before overwriting.')

	parser.add_argument('-s', '--sim_threshold', type=float,
		default=0.97, help='choose threshold to combine redundant clusters at '
		'a given K. Default = 0.97')
	parser.add_argument('-d', '--dif_threshold', type=float,
		default=0.1, help='choose threshold for confidence of best cluster for '
		'combining clusters at a given K. Choose 0 in order not to consider '
		'dif. Default = 0.1')
	parser.add_argument('--col_delim', default=None,
		help='Provide the character on which to split columns. Default is '
		'whitespace (of any length).')
	parser.add_argument('--dist_metric',
		default='jaccard', help='distance metric to be used for comparing '
		'cluster similarities. Choose from %s. Default = jaccard'
		% str(dist_metrics))
	parser.add_argument('--disable_server', default=False, action='store_true',
		help='run pong\'s algorithm without initializing a server instance or '
		'visualizing results.')
	parser.add_argument('-p','--port', type=int, default=4000,
		help='Specify port on which the server should locally host. Default = 4000.')
	parser.add_argument('-v', '--verbose', default=False,
		action='store_true', help='Report more details about clustering '
		'results to the command line, and print all cluster distances in the '
		'output files (by default, only the best 5 are printed).')
	parser.add_argument('--override_prompts', default=False,
		action='store_true', help='Don\'t stop and prompt the user with a '
		'warning when there may be a problem.')

	opts = parser.parse_args()

	# Check validity of pongparams file
	pong_filemap = path.abspath(opts.filemap)
	if not path.isfile(pong_filemap):
		sys.exit('Error: Could not find pong filemap at %s.' % pong_filemap)

	# Check validity of specified distance metric
	if opts.dist_metric not in dist_metrics:
		x = (opts.dist_metric, str(dist_metrics))
		sys.exit('Invalid distance metric: "%s". Please choose from %s' % x)

	# NOTE: print_all used to be forced on for very small K, but K_max is not
	# known before the input is parsed, so --verbose alone decides for now.
	printall = opts.verbose

	# Optional population data / labels
	ind2pop = None
	labels = None

	# compare against None rather than truthiness: column number 0 is valid
	if opts.ind2pop is not None:
		try:
			ind2pop = int(opts.ind2pop)
		except ValueError:
			# not a column number, so it must be a path to a file
			ind2pop = path.abspath(opts.ind2pop)
			if not path.isfile(ind2pop):
				sys.exit('Error: Could not find ind2pop file at %s.' % ind2pop)

	if opts.pop_names is not None:
		if ind2pop is None:
			sys.exit('Error: must provide ind to pop data in order to provide '
				'pop order data')
		labels = path.abspath(opts.pop_names)
		if not path.isfile(labels):
			sys.exit('Error: Could not find pop order file at %s.' % labels)

	# Check validity of color file; `colors` stays empty when default colors
	# are to be used
	colors = []
	if opts.color_list:
		color_file = path.abspath(opts.color_list)
		if not path.isfile(color_file):
			sys.stdout.write('\nWarning: Could not find color file '
				'at %s.\n' % color_file)

			if opts.override_prompts:
				sys.stdout.write('Continuing with default colors.\n')
			elif not _confirm('Continue using default colors? (y/n): ',
					'Please enter "y" to continue or "n" to exit: '):
				sys.exit(1)
		else:
			sys.stdout.write('\nCustom colors provided. Visualization utilizes the '
				'color white.\nIf color file contains white, users are advised to '
				'replace it with another color.\n')
			with open(color_file,'r') as f:
				colors = [x for x in [l.strip() for l in f] if x != '']

	# Check and clean output dir: an existing dir is only removed with
	# --force or after the user explicitly agrees
	outputdir = path.abspath(opts.output_dir)
	if path.isdir(outputdir):
		if not opts.force:
			outputdir_name = path.split(outputdir)[1]
			print('\nOutput dir %s already exists.' % outputdir_name)
			if opts.override_prompts:
				print('Use option `--force` to overwrite.')
				sys.exit(1)

			if not _confirm('Override? (y/n): ',
					'Please enter "y" to overwrite or "n" to exit: '):
				sys.exit(1)
		rmtree(outputdir)

	os.makedirs(outputdir)

	# Initialize object to hold references to all main pong data
	pongdata = Pongdata(intro, outputdir, printall)
	pongdata.colors = colors

	# Record the parameters of this run alongside its output
	params_used = intro+'\n\n===============\n'
	params_used += 'pong_filemap file: %s\n' % pong_filemap
	params_used += 'Distance metric: %s\n' % opts.dist_metric
	params_used += 'Similarity threshold: %f\n' % opts.sim_threshold
	params_used += 'Difference threshold: %f\n' % opts.dif_threshold
	params_used += 'Verbose: %s\n' % str(pongdata.print_all)

	with open(path.join(pongdata.output_dir,'params_used.txt'),'w') as f:
		f.write(params_used)

	# Stored globally so the websocket handler can start the pipeline itself
	run_pong_args = (pongdata, opts, pong_filemap, labels, ind2pop)

	# ========================= RUN PONG ======================================

	print(pongdata.intro)

	if opts.disable_server:
		run_pong(*run_pong_args)
	else:
		app = Application()
		app.listen(opts.port)

		msg = 'pong server is now running locally & listening on port %s\n' % opts.port
		msg += 'Open your web browser and navigate to localhost:%s to see the visualization\n\n'% opts.port
		sys.stdout.write(msg)

		try:
			tornado.ioloop.IOLoop.current().start()
		except KeyboardInterrupt:
			# Ctrl-C is the normal way to stop the server
			sys.stdout.write('\n')
			sys.exit(0)




def run_pong(pongdata, opts, pong_filemap, labels, ind2pop):
	'''
	Run the whole pong pipeline on `pongdata`, step by step: parse the input,
	match clusters within and across K, compute alignments, write the result
	files and generate the color information.

	`pongdata.status` is set to 1 when the run starts and to 2 once every
	step has finished. Timing information is printed at the end.
	'''
	pongdata.status = 1

	started = time.time()

	# parse input file and organize data into groups of runs per K
	print('Parsing input and generating cluster network graph')
	parse.parse_multicluster_input(
		pongdata, pong_filemap, opts.ignore_cols, opts.col_delim, labels,
		ind2pop)

	# match clusters for runs within each K, condense to representative runs
	print('Matching clusters within each K and finding representative runs')
	match_started = time.time()
	cm.clumpp(
		pongdata, opts.dist_metric, opts.sim_threshold, opts.dif_threshold)

	# match clusters across K
	print('Matching clusters across K')
	cm.multicluster_match(pongdata, opts.dist_metric)
	match_finished = time.time()

	# write out the cluster matching results
	write.output_cluster_match_details(pongdata)

	# compute best-guess alignments for all runs within and across K
	print('Finding best alignment for all runs within and across K')
	align_started = time.time()
	# the 2 is a dummy worst_choice, kept until compatibility can be removed
	align.compute_alignments(pongdata, 2, opts.sim_threshold)
	align_finished = time.time()

	# write out the best-guess alignments
	write.output_alignments(pongdata)

	# generate color info
	parse.convert_data(pongdata)
	distruct.generate_color_perms(pongdata)
	if len(pongdata.colors) != 0:
		print('Generating perm files for Distruct')
		distruct.generate_distruct_perm_files(pongdata, pongdata.colors)

	pongdata.status = 2

	match_secs = match_finished - match_started
	align_secs = align_finished - align_started
	# total leaves out the time spent writing the cluster match details
	total_secs = (match_finished - started) + (align_finished - align_started)

	print('match time: %.2fs' % match_secs)
	print('align time: %.2fs' % align_secs)
	print('total time: %.2fs' % total_secs)






class Application(tornado.web.Application):
	'''
	Tornado application for pong: routes "/" to MainHandler and
	"/pongsocket" to WSHandler, with templates and static files taken from
	inside the installed pong package.
	'''
	def __init__(self):
		package_dir = pong.__path__[0]
		routes = [
			(r"/", MainHandler),
			(r"/pongsocket", WSHandler),
		]
		tornado.web.Application.__init__(
			self,
			routes,
			template_path=path.join(package_dir, "templates"),
			static_path=path.join(package_dir, "static"),
		)


class MainHandler(tornado.web.RequestHandler):
	'''Request handler for the site root.'''
	def get(self):
		# answer a GET request by rendering the pong page template
		page_template = "pong.html"
		self.render(page_template)

class WSHandler(tornado.websocket.WebSocketHandler):
	'''
	Websocket endpoint used by the browser front end.

	On connect it makes sure the pong pipeline has been run and sends the
	results; afterwards it answers Q-matrix requests and saves SVG plots
	sent by the client. Incoming messages are JSON objects whose 'type'
	field selects the action.
	'''
	clients = set() # every currently open connection

	def open(self):
		'''
		Register the new connection, run pong if it has not been run yet, and
		send the resulting data to the client as a 'pong-data' message.
		'''
		print('New browser connection')
		WSHandler.clients.add(self)

		# Server is not asynchronous so it won't serve a partially-completed
		# Pong object: the first connection runs the pipeline to completion
		if pongdata.status == 0:
			run_pong(*run_pong_args)

		pong_json_data = write.write_json(pongdata)

		self.write_message(json.dumps({'type': 'pong-data',
			'pong': pong_json_data}))

		print('Generating visualization')

	def on_close(self):
		'''Forget the connection once the browser disconnects.'''
		# discard (not remove) so a connection that was never registered
		# cannot raise here
		WSHandler.clients.discard(self)
		print('Browser disconnected')

	def on_message(self, message):
		'''
		Decode one JSON message from the client and dispatch on its 'type':
		'button-clicked', 'get-qmatrix', 'svg' or 'multi-svg'. Any other type
		is reported and ignored.
		'''
		data = json.loads(message)
		msg_type = data.get('type')

		if msg_type == 'button-clicked':
			print('received button click %s from client' % data['info'])
			self.write_message(json.dumps({'type': 'button-response',
				'response':'nm u?'}))
		elif msg_type == 'get-qmatrix':
			self._send_qmatrix(data)
		elif msg_type == 'svg':
			self._save_svg(data)
		elif msg_type == 'multi-svg':
			self._save_svg_batch(data)
		else:
			# A bad message from one client must not take the whole server
			# down, so report it instead of exiting the process.
			sys.stderr.write('Error: Received invalid socket message from client\n')

	@staticmethod
	def _svg_filename(name):
		'''
		Turn a client-supplied plot name into a bare "<name>.svg" file name.
		Directory components are dropped so the client cannot choose where on
		disk the file is written.
		'''
		return path.basename(name) + '.svg'

	def _send_qmatrix(self, data):
		'''Answer a 'get-qmatrix' request with the data of the named run.'''
		name = data['name']
		run = pongdata.runs[pongdata.name2id[name]] # the run instance
		# all three fields are read up front, as every request must carry them
		minor = data['minor']
		minorID = data['minorID']
		is_first = data['is_first']

		# minorID / is_first are only echoed back for minor-mode requests
		is_minor = minor == 'yes'
		response = {
			'type': 'q-matrix',
			'name': name,
			'K': run.K,
			'matrix2d': run.data_transpose_2d,
			'minor': 'yes' if is_minor else 'no',
			'minorID': minorID if is_minor else None,
			'is_first': is_first if is_minor else None,
		}

		self.write_message(json.dumps(response))

	def _save_svg(self, data):
		'''
		Save a single SVG plot to ~/Downloads, or to the output dir when no
		Downloads folder exists.
		'''
		filename = self._svg_filename(data['name'])

		dl_dir = path.expanduser('~/Downloads')
		if path.isdir(dl_dir):
			print('Saving file %s to your Downloads folder.' % filename)
		else:
			dl_dir = pongdata.output_dir
			print('Could not find Downloads folder; saving file %s to the output dir.' % filename)

		with open(path.join(dl_dir, filename), 'w') as f:
			f.write(data['svg'])

	def _save_svg_batch(self, data):
		'''
		Save every plot in data['svg-dict'] into a fresh "plotSVGs" folder
		inside the output dir, replacing the folder if it already exists.
		'''
		svg_dir = path.join(pongdata.output_dir, 'plotSVGs')
		print('Saving plot SVGs in %s.' % svg_dir)

		# prepare SVG dir within output dir
		if path.isdir(svg_dir):
			rmtree(svg_dir)
		os.makedirs(svg_dir)

		for name, svg in data['svg-dict'].items():
			with open(path.join(svg_dir, self._svg_filename(name)), 'w') as f:
				f.write(svg)













# Start pong only when this file is executed as a script, not when imported.
if __name__ == '__main__':
	main()
