#!/usr/bin/env python

'''
Functions to cluster-match across multiple values of K.

For now, it takes in 2 files; they must have adjacent K values.

author: Aaron Behr
created: 2014-06-29
'''
import sys
import numpy as np
import os
from os import path
import argparse
import time
import json
import tornado.ioloop
import tornado.web
import tornado.websocket
from cairosvg import svg2pdf

import pong
from pong import parse, cm, write, align, distruct


# Module-level state. `pongdata` and `run_pong_args` start empty and are
# filled in by main(); the websocket handler reads them afterwards.
# NOTE(review): `clients` and `threads` are not referenced by the code visible
# in this file (WSHandler keeps its own `clients` set) -- confirm before removing.
clients, threads = [], []
pongdata = run_pong_args = None

class Pongdata(object):
	'''Container holding references to all main pong data for one invocation.

	Declared as a new-style class (explicit `object` base) so that it behaves
	the same way under Python 2 as every other class in the program.

	Constructor arguments:
		intro     -- banner text, stored as `self.intro`
		outputdir -- output directory path, stored as `self.output_dir`
		printall  -- verbosity flag, stored as `self.print_all`
	'''
	def __init__(self, intro, outputdir, printall):
		self.runs = {} # contains all Run objects
		self.all_kgroups = [] # contains kgroups in order
		self.cluster_matches = {} # all clustering solutions matching 2 runs

		self.name2id = {} # run name to run ID

		# -1 marks "not known yet" for the following three values.
		self.num_indiv = -1
		self.K_min = -1
		self.K_max = -1

		self.intro = intro
		self.output_dir = outputdir
		self.print_all = printall

		# Population / labelling information; all None until something sets it.
		# NOTE(review): presumably filled in by the parsing step -- confirm.
		self.ind2pop = None
		self.pop_order = None
		self.popcode2popname = None
		self.popindex2popname = None
		self.pop_sizes = None
		self.sort_by = None
		self.indiv_avg = None

		self.colors = [] # use custom colors?

		self.status = 0 # incomplete, working, or complete (0,1,2)


# Start-up banner: a rule, the program name, the credits and version, then
# two closing rules. Every line ends with a newline.
intro = ''.join([
	'-------------------------------------------------------------------\n',
	'                            p o n g\n',
	# older credit line, kept disabled:
	# '                 by A. Behr and S. Ramachandran\n',
	'          by A. Behr, G. Fang, K. Liu, and S. Ramachandran\n',
	'                       Version 0.4 (2015)\n',
	'-------------------------------------------------------------------\n',
	'-------------------------------------------------------------------\n',
])






def _confirm(prompt, reprompt):
	'''Ask a yes/no question on stdin and return True for yes, False for no.

	`prompt` is shown first; `reprompt` is shown repeatedly until the answer
	is one of y, Y, n, N.
	'''
	answer = raw_input(prompt)
	while answer not in ('y', 'Y', 'n', 'N'):
		answer = raw_input(reprompt)
	return answer in ('y', 'Y')


def _build_arg_parser(dist_metrics):
	'''Return the argparse parser describing pong's command-line options.'''
	parser = argparse.ArgumentParser(description='pong, by A. Behr and S. '
		'Ramachandran')

	parser.add_argument('-m', '--filemap', required=True,
		help='path to params file containing information about input '
		'Q-matrix files')

	parser.add_argument('-c', '--ignore_cols', type=int, default=0,
		help='ignore the first i columns of every data line. Typically 5 for '
		'Structure output and 0 for ADMIXTURE output. Default = 0')
	parser.add_argument('--col_delim', default=None,
		help='Provide the character on which to split columns. Default is '
		'whitespace (of any length).')
	parser.add_argument('--dist_metric',
		default='jaccard', help='distance metric to be used for comparing '
		'cluster similarities. Choose from %s. Default = jaccard'
		% str(dist_metrics))
	parser.add_argument('-s', '--sim_threshold', type=float,
		default=0.97, help='choose threshold to combine redundant clusters at '
		'a given K. Default = 0.97')
	parser.add_argument('-d', '--dif_threshold', type=float,
		default=0.1, help='choose threshold for confidence of best cluster for '
		'combining clusters at a given K. Choose 0 in order not to consider '
		'dif. Default = 0.1')
	parser.add_argument('-o', '--output_dir',
		default='./pong_output', help='specify output dir for files to be '
		'written to. Default = "./pong_output"')
	parser.add_argument('-v', '--print_all', default=False,
		action='store_true', help='print all cluster distances. By default, '
		'only the best 5 are printed.')
	parser.add_argument('-f', '--force', default=False,
		action='store_true', help='force overwrite already existing output '
		'directory. By default, pong will prompt the user before overwriting.')
	parser.add_argument('-l', '--color_list',
		help='List of colors to be used for visualization. If this file is not '
		'included, then default colors will be used for visualization. '
		'This file must include at least max_K colors; '
		'if more are given then the first max_K colors are used.')
	parser.add_argument('--override_prompts', default=False,
		action='store_true', help='Don\'t stop and prompt the user with a '
		'warning when there may be a problem.')

	parser.add_argument('-n', '--pop_names', default=None,
		help='order for population names (fp)')
	parser.add_argument('-i', '--ind2pop', default=None,
		help='ind2pop data (col num for structure, or a filepath for admixture)')

	parser.add_argument('--disable_server', default=False, action='store_true',
		help='Do not run the server, just run pong')
	parser.add_argument('-p', '--port', type=int, default=4000,
		help='Specify port on which the server should locally host. Default = 4000.')
	return parser


def _resolve_pop_data(ind2pop_opt, pop_names_opt):
	'''Validate the optional population arguments; return (ind2pop, labels).

	`ind2pop` is an int when the option parses as an integer, otherwise the
	absolute path of an existing file. `labels` is the absolute path of the
	population-order file. Either is None when its option was not given.
	Exits with an error message when a named file is missing, or when
	population names are given without ind2pop data.
	'''
	ind2pop = None
	labels = None

	# Compare against None rather than truthiness: "0" is a legitimate value.
	if ind2pop_opt is not None:
		try:
			ind2pop = int(ind2pop_opt)
		except ValueError:
			ind2pop = path.abspath(ind2pop_opt)
			if not path.isfile(ind2pop):
				sys.exit('Error: Could not find ind2pop file at %s.' % ind2pop)

	if pop_names_opt is not None:
		if ind2pop is None:
			sys.exit('Error: must provide ind to pop data in order to provide '
				'pop order data')
		labels = path.abspath(pop_names_opt)
		if not path.isfile(labels):
			sys.exit('Error: Could not find pop order file at %s.' % labels)

	return ind2pop, labels


def _load_colors(color_list, override_prompts):
	'''Return the non-blank, stripped lines of the color file.

	Returns [] (meaning "use default colors") when no file was requested, or
	when the file is missing and the user, or `override_prompts`, allows
	continuing. Exits with status 1 if the user declines.
	'''
	if not color_list:
		return []

	color_file = path.abspath(color_list)
	if path.isfile(color_file):
		with open(color_file, 'r') as f:
			stripped = [line.strip() for line in f]
		return [entry for entry in stripped if entry != '']

	sys.stdout.write('\nWarning: Could not find color file '
		'at %s.\n' % color_file)
	if override_prompts:
		sys.stdout.write('Continuing with default colors.\n')
	elif not _confirm('Continue using default colors? (y/n): ',
			'Please enter "y" to continue or "n" to exit: '):
		sys.exit(1)
	return []


def _prepare_output_dir(output_dir, force, override_prompts):
	'''Create a fresh output directory and return its absolute path.

	An existing directory is removed first: silently with `force`, otherwise
	only after the user confirms. With `override_prompts` and no `force`, an
	existing directory makes the program exit with status 1.
	'''
	import shutil

	outputdir = path.abspath(output_dir)
	if path.isdir(outputdir):
		if not force:
			print('\nOutput dir %s already exists.' % path.split(outputdir)[1])
			if override_prompts:
				print('Use option `--force` to overwrite.')
				sys.exit(1)
			if not _confirm('Override? (y/n): ',
					'Please enter "y" to overwrite or "n" to exit: '):
				sys.exit(1)
		shutil.rmtree(outputdir)

	os.makedirs(outputdir)
	return outputdir


def main():
	'''Command-line entry point.

	Parses and validates the arguments, prepares the output directory, writes
	params_used.txt, then either runs pong directly (--disable_server) or
	starts the local Tornado server, which runs pong when the first browser
	connects. Sets the module-level `pongdata` and `run_pong_args`.
	'''
	global pongdata, run_pong_args

	dist_metrics = ['sum_squared', 'percent', 'G', 'jaccard']
	opts = _build_arg_parser(dist_metrics).parse_args()

	# Check validity of the filemap file
	pong_filemap = path.abspath(opts.filemap)
	if not path.isfile(pong_filemap):
		sys.exit('Error: Could not find pong filemap at %s.' % pong_filemap)

	# Check validity of specified distance metric
	if opts.dist_metric not in dist_metrics:
		sys.exit('Invalid distance metric: "%s". Please choose from %s'
			% (opts.dist_metric, str(dist_metrics)))

	# print_all is passed through unchanged; K_max is not known at this point,
	# so it cannot be forced on for very small K here.
	printall = opts.print_all

	ind2pop, labels = _resolve_pop_data(opts.ind2pop, opts.pop_names)
	colors = _load_colors(opts.color_list, opts.override_prompts)
	outputdir = _prepare_output_dir(opts.output_dir, opts.force,
		opts.override_prompts)

	# Initialize object to hold references to all main pong data
	pongdata = Pongdata(intro, outputdir, printall)
	pongdata.colors = colors

	params_used = intro+'\n\n===============\n'
	params_used += 'pong_filemap file: %s\n' % pong_filemap
	params_used += 'Distance metric: %s\n' % opts.dist_metric
	params_used += 'Similarity threshold: %f\n' % opts.sim_threshold
	params_used += 'Difference threshold: %f\n' % opts.dif_threshold
	params_used += 'Verbose: %s\n' % str(pongdata.print_all)

	with open(os.path.join(pongdata.output_dir, 'params_used.txt'), 'w') as f:
		f.write(params_used)

	# Stored globally so the websocket handler can start the run later.
	run_pong_args = (pongdata, opts, pong_filemap, labels, ind2pop)

	# ========================= RUN PONG ======================================

	print(pongdata.intro)

	if opts.disable_server:
		run_pong(*run_pong_args)
		return

	app = Application()
	app.listen(opts.port)

	msg = 'pong server is now running locally & listening on port %s\n' % opts.port
	msg += 'Open your web browser and navigate to localhost:%s to see the visualization\n\n' % opts.port
	sys.stdout.write(msg)

	try:
		tornado.ioloop.IOLoop.current().start()
	except KeyboardInterrupt:
		# Ctrl-C is the normal way to stop the server; exit quietly.
		sys.stdout.write('\n')
		sys.exit(0)




def run_pong(pongdata, opts, pong_filemap, labels, ind2pop):
	'''Run the whole pong pipeline, storing its results on `pongdata`.

	Arguments:
		pongdata     -- Pongdata instance that the pipeline steps work on
		opts         -- parsed command-line options; this function reads
		                ignore_cols, col_delim, dist_metric, sim_threshold
		                and dif_threshold
		pong_filemap -- path of the filemap file handed to the parser
		labels       -- population-order file path, or None
		ind2pop      -- ind2pop column number or file path, or None

	`pongdata.status` is set to 1 on entry and to 2 once every step has
	finished; it stays at 1 if a step raises.
	'''
	pongdata.status = 1

	# Parse input file and organize data into groups of runs per K
	print('Parsing input and generating cluster network graph')
	parse.parse_multicluster_input(pongdata, pong_filemap, opts.ignore_cols,
		opts.col_delim, labels, ind2pop)

	# Match clusters for runs within each K and condense to representative runs
	print('Matching clusters within each K and finding representative runs')
	cm.clumpp(pongdata, opts.dist_metric, opts.sim_threshold, opts.dif_threshold)

	# Match clusters across K, then write the match details
	print('Matching clusters across K')
	cm.multicluster_match(pongdata, opts.dist_metric)
	write.output_cluster_match_details(pongdata)

	# Compute best-guess alignments for all runs within and across K.
	# The literal 2 is a dummy worst_choice, kept only until that parameter
	# can be removed from compute_alignments.
	print('Finding best alignment for all runs within and across K')
	align.compute_alignments(pongdata, 2, opts.sim_threshold)
	write.output_alignments(pongdata)

	# Generate color info; Distruct perm files only when custom colors exist
	parse.convert_data(pongdata)
	distruct.generate_color_perms(pongdata)
	if len(pongdata.colors) > 0:
		print('Generating perm files for Distruct')
		distruct.generate_distruct_perm_files(pongdata, pongdata.colors)

	pongdata.status = 2







class Application(tornado.web.Application):
	'''Tornado application serving the pong page and its websocket.

	Routes "/" to MainHandler and "/pongsocket" to WSHandler; templates and
	static files are looked up inside the installed `pong` package directory.
	'''
	def __init__(self):
		package_dir = pong.__path__[0]
		routes = [
			(r"/", MainHandler),
			(r"/pongsocket", WSHandler),
		]
		super(Application, self).__init__(
			routes,
			template_path=path.join(package_dir, 'templates'),
			static_path=path.join(package_dir, 'static')
		)


class MainHandler(tornado.web.RequestHandler):
	'''Request handler for "/": renders the pong.html template.'''
	def get(self):
		template_name = "pong.html"
		self.render(template_name)


class WSHandler(tornado.websocket.WebSocketHandler):
	'''Websocket endpoint that feeds pong results to the browser.

	`clients` is a class-level set of the handler instances that are
	currently open. The handler reads the module-level `pongdata` and
	`run_pong_args`, which main() sets before the server starts.
	'''
	clients = set()

	def open(self):
		'''Register the connection and send it the full pong data as JSON.

		If pong has not been run yet (status 0) it is run here, synchronously,
		before anything is sent.
		'''
		print('New browser connection')
		WSHandler.clients.add(self)

		# Server is not asynchronous so it won't serve a partially-completed Pong object
		if pongdata.status == 0:
			run_pong(*run_pong_args)

		pong_json_data = write.write_json(pongdata)
		self.write_message(json.dumps({'type': 'pong-data',
			'pong': pong_json_data}))

		print('Generating visualization')

	def on_close(self):
		'''Forget the connection once the browser disconnects.'''
		# discard() rather than remove(): closing a handler that was never
		# registered must not raise KeyError.
		WSHandler.clients.discard(self)
		print('Browser disconnected')

	def on_message(self, message):
		'''Dispatch one JSON message from the browser on its 'type' field.

		Handled types:
			'button-clicked' -- replies with a fixed 'button-response' message
			'get-qmatrix'    -- replies with the requested run's Q-matrix
			'svg'            -- converts the SVG to <name>.pdf in the output dir
		Any other type terminates the program through sys.exit.
		'''
		# Decode once; the message used to be parsed twice with the first
		# result thrown away.
		data = tornado.escape.json_decode(message)
		msg_type = data['type']

		if msg_type == 'button-clicked':
			print('received button click %s from client' % data['info'])
			self.write_message(json.dumps({'type': 'button-response',
				'response': 'nm u?'}))

		elif msg_type == 'get-qmatrix':
			name = data['name']
			run = pongdata.runs[pongdata.name2id[name]] # the Run instance
			# All three fields are required, even for a non-minor request.
			is_minor = data['minor'] == 'yes'
			minorID = data['minorID']
			is_first = data['is_first']

			# minorID / is_first are echoed back only for a "minor" request.
			response = {
				'type': 'q-matrix',
				'name': name,
				'K': run.K,
				'matrix2d': run.data_transpose_2d,
				'minor': 'yes' if is_minor else 'no',
				'minorID': minorID if is_minor else None,
				'is_first': is_first if is_minor else None,
			}
			self.write_message(json.dumps(response))

		elif msg_type == 'svg':
			# The name comes from the client: keep only its final path
			# component so the PDF cannot land outside the output dir.
			name = path.basename(data['name'])
			print('Creating file %s.pdf in the output dir.' % name)
			hidden_path = path.join(pongdata.output_dir, '.' + name + '.pdf')
			final_path = path.join(pongdata.output_dir, name + '.pdf')
			# Write under a dot-prefixed name, then rename, so the file does
			# not show up until it is complete.
			svg2pdf(data['svg'], write_to=hidden_path)
			os.rename(hidden_path, final_path)

		else:
			sys.exit('Error: Received invalid socket message from client')




	# @classmethod
	# def stream_something():
	#	 print 'streaming something to start'
	#	 data = {'type' : 'stream-from-server', 'data': 'here ya go'}

	#	 for client in cls.clients:
	#		 client.write_message(data)













# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
	main()
