#!/usr/bin/python

import math
import sys
import cPickle
import numpy
from gmisclib import die
from gmisclib import dictops
from gmisclib import load_mod
import gmisclib.Numeric_gpk as NG
from gmisclib import weighted_percentile as WP

from newstem2 import logtools as LT
from newstem2 import indexclass as IC

# plot_logp passes [TRIM, 1.0-TRIM] to WP.wp for each log file and uses the
# two returned values as the basis of the y-axis range
# (presumably lower/upper quantile points, so outliers are ignored -- TODO confirm).
TRIM = 0.05
# plot_logp widens that range by STRETCH*(max-min) at both the top and the bottom.
STRETCH = 0.15



def _log_mean_exp(logvals):
	"""Return log(mean(exp(v) for v in logvals)), computed stably.

	The largest value is subtracted before exponentiating so that nothing
	overflows.  Terms more than 100 below the largest are counted as
	exp(-100) rather than as their (even smaller) true value.
	C{logvals} must be non-empty.
	"""
	top = max(logvals)
	total = 0.0
	for v in logvals:
		total += math.exp(max(-100.0, v - top))
	return top + math.log(total/len(logvals))


def P_bayes(per_fn, argv, m, arg, selector, hdr):
	"""Compute two model-comparison scores from the selected samples.

	For every (log, index) pair produced by C{selector(per_fn)}, the
	problem definition built by C{m.pd_factory(argv, hdr=hdr)} is asked for
	C{logp_data_normalized} and C{logp_prior_normalized} of that sample.
	C{arg} is accepted for interface compatibility and is not used.

	@return: a list of two C{(name_tuple, value)} pairs, named
		C{Log(PosteriorMarginalLikelihood)} and C{Log(BayesWeightedBayes)}.
	@raise ValueError: if the selector yields no samples.
	@raise IC.IndexKeyError: re-raised (after being printed) when a sample
		lacks a parameter that the problem definition needs.

	Log(PosteriorMarginalLikelihood) is the posterior marginal likelihood of this model:
	P_posterior[Data|model] =
		integral( P[D|params,model] * P[params|data,model] * d_params)
		it's the average of P[D|params,model] over the posterior distribution
		of P[params|data,model]*Prior(params).
	(This assumes the selected samples are draws from that posterior.)
	References are:
	Rampal S. Etienne, Han Olff (2005)
		Confronting different models of community structure to
		species-abundance data: a Bayesian model comparison
		Ecology Letters 8 (5) , 493-504 doi:10.1111/j.1461-0248.2005.00745.x
	and that references
	M. Aitkin 1991:
		Posterior Bayes Factors, J. of the Royal Statistical Soc. B 53: 111-142
	P. W. Laud and J. G. Ibrahim 1995:
		Predictive Model Selection, J. of the Royal Statistical Soc. B 57: 247-262.
	F. De Santis and F. Spezzaferri 1997:
		Alternative Bayes Factors for Model Selection,
		Canadian J. of Statistics 25: 503-515
	S. K. Upadhyay and M. Peswani 2003:
		Choice between Weibull and Lognormal Models: a simulation based Bayesian Study
		Communications in Statistics: Theory and Methods 32: 318-405
	P. K. Vlachos and A. E. Gelfand 2003:
		On the calibration of Bayesian model choice criteria,
		J. of Statistical Planning and Inference 111: 223-234
	R. E. Kass and A. E. Raftery 1995:
		Bayes Factors, J. of the American Statistical Assoc. 90: 773-795


	The other thing, Log(BayesWeightedBayes)
	is the average of P(D|params,M)*Prior(params) over the posterior distribution
	of P[params|data,model]*Prior(params).
	It has no real statistical backing, but it's a crude approximation
	for the Bayes evidence itself,
	the normalized P[params|data,model]*Prior(params).
	"""
	pd = m.pd_factory(argv, hdr=hdr)
	lplist = []	# log P(data|params) for each sample
	lblist = []	# log P(data|params) + log Prior(params) for each sample
	for (ol, i) in selector(per_fn):
		idxr = ol.indexers[i]
		try:
			lpdata = pd.logp_data_normalized(idxr)
			lpprior = pd.logp_prior_normalized(idxr)
		except IC.IndexKeyError as ke:
			LT.print_index_error(ke)
			raise
		lplist.append(lpdata)
		lblist.append(lpdata + lpprior)
	if not lplist:
		raise ValueError("P_bayes: the selector produced no samples")
	# Each list is shifted by its OWN maximum.  Shifting lblist by max(lplist)
	# pins the result near max(lplist)-100 (or overflows) whenever the prior
	# term is large in magnitude.
	return [
		(('Log(PosteriorMarginalLikelihood)',), _log_mean_exp(lplist)),
		(('Log(BayesWeightedBayes)',), _log_mean_exp(lblist)),
		]



def convergence(per_fn, argv, m, arg, selector, hdr):
	"""Feed every selected sample to the model's convergence display, then show it.

	The problem definition comes from C{m.pd_factory(argv, hdr=hdr)}.  Its
	C{convergence} method is called once per (log, index) pair yielded by
	C{selector(per_fn)}, and one last time with C{None} in place of the
	sample and its index (presumably so it can finish the figure once all
	samples have been seen).  A single list is shared across all those calls
	so the method can accumulate state in it.
	"""
	import pylab
	problem = m.pd_factory(argv, hdr=hdr)
	shared = []
	for (log, which) in selector(per_fn):
		sample = log.indexers[which]
		problem.convergence(sample, arg, shared, pylab, which)
	# Final call with no sample: end-of-data signal.
	problem.convergence(None, arg, shared, pylab, None)
	pylab.show()



def do_pd_plot(per_fn, argv, m, arg, selector, hdr):
	"""Let the model's problem definition plot every selected sample, then show the figure.

	C{pd.plot} is called once per (log, index) pair yielded by
	C{selector(per_fn)} and then once more with C{None} for the sample and
	its index, which gives it a chance to do any post-processing after all
	the points are computed.
	"""
	import pylab
	problem = m.pd_factory(argv, hdr=hdr)
	for (log, which) in selector(per_fn):
		problem.plot(log.indexers[which], arg, pylab, which)
	# End-of-data call: allows post-processing once every point has been seen.
	problem.plot(None, arg, pylab, None)
	pylab.show()


def truncate(s, n):
	"""Return C{s} shortened to at most C{n} characters.

	A string that already fits is returned unchanged.  Otherwise the tail is
	replaced by '...'.  When C{n} is too small to hold any text plus the
	ellipsis, only as many dots as fit are returned, so the result is never
	longer than C{n} (and is empty for C{n <= 0}).
	"""
	if len(s) <= n:
		return s
	if n <= 3:
		return '.' * max(0, n)
	return s[:n-3] + '...'


def plot_logp(per_fn, argv, m, arg, selector, hdr):
	"""Plot the logged log-probability trace of every file in C{per_fn}.

	One curve is drawn per log file.  The y-axis is limited to the span
	between the smallest and largest of the values WP.wp returns for
	[TRIM, 1.0-TRIM] across all files, widened by STRETCH times that span
	on each side.  The title is the (truncated) command line C{argv}.
	C{m}, C{arg}, C{selector} and C{hdr} are accepted for interface
	compatibility and are not used.
	"""
	import pylab
	lows = []
	highs = []
	for log in per_fn.values():
		pylab.plot(log.logps)
		pylab.xlabel('Logged iterations')
		pylab.ylabel('Log probability that model predicts data')
		pylab.title(truncate(' '.join(argv), 40))
		lo, hi = WP.wp(log.logps, None, [TRIM, 1.0-TRIM])
		lows.append(lo)
		highs.append(hi)
	bottom = min(lows)
	top = max(highs)
	margin = STRETCH*(top-bottom)
	pylab.ylim(bottom - margin, top + margin)
	pylab.show()


def do_pd_print(per_fn, argv, m, arg, selector, hdr):
	"""Let the model's problem definition print every selected sample.

	C{pd.do_print} is called once per (log, index) pair yielded by
	C{selector(per_fn)}, then once more with C{None} for the sample and its
	index, and finally standard output is flushed.
	"""
	# NOTE(review): hdr is passed positionally here, whereas the sibling
	# functions in this file pass hdr=hdr -- confirm both reach the same parameter.
	problem = m.pd_factory(argv, hdr)
	for (log, which) in selector(per_fn):
		sample = log.indexers[which]
		problem.do_print(sample, arg, which)
	# This is to allow any post-processing after all the points are computed
	problem.do_print(None, arg, None)
	sys.stdout.flush()


class ModelEvaluator(object):
	"""Bind an analysis function to a model module and one extra argument.

	Calling the instance rebuilds the command line recorded in the log
	headers, appends any extra arguments, and invokes
	C{fcn(per_fn, argv, mod, arg, selector, hdrs)}.
	"""

	def __init__(self, mod, fcn, arg=None):
		"""Remember the model module, the function to run and its argument."""
		self.arg = arg
		self.fcn = fcn
		self.mod = mod


	def __call__(self, per_fn, hdrs, selector, xargs):
		"""Run the stored function and return whatever it returns.

		The argument list is taken from the pickled 'Argv' header when
		present, otherwise from the whitespace-separated 'ARGV' header.
		"""
		if 'Argv' not in hdrs:
			argv = hdrs['ARGV'].split()
		else:
			# NOTE(review): this unpickles data read from a log file;
			# only safe when the log files are trusted.
			argv = cPickle.loads(hdrs['Argv'])
		argv.extend(xargs)
		message = "Args=%s" % (' '.join(argv))
		die.info(message)
		return self.fcn(per_fn, argv, self.mod, self.arg, selector, hdrs)




def compute_pfn_wts(per_fn, selector):
	"""Count how many selected samples come from each log file.

	Returns the sorted items of the accumulator, keyed by each log's
	C{fname}; the caller unpacks them as (filename, sample count).
	"""
	tally = dictops.dict_of_accums()
	# Seed every file with zero so that files contributing no samples still appear.
	for log in per_fn.values():
		tally[log.fname] = 0
	for (log, _which) in selector(per_fn):
		tally.add(log.fname, 1)
	return sorted(tally.items())


def print_correlations(per_fn, selector):
	mean, covar, n, idxr_map = LT.indexer_covar(per_fn, selector)
	if covar is None:
		return
	rmap = {}
	for (k, i) in idxr_map.items():
		rmap[i] = k
	evals, evecs = numpy.linalg.eigh(covar)
	m = evals.shape[0]
	mev = 0.5*(evals[m//2]+evals[(m+1)//2])
	for i in range(max(0, m-10), m):
		if evals[i] > max(3*mev, evals[-1]*0.03):
			print '# eigenvalue %d eval= %g (median eval= %g )' % (i, evals[i], mev)
			v = evecs[:,i]
			mxv = NG.N_maximum(numpy.absolute(v))
			tmp = []
			f = math.sqrt(evals[-1]/evals[i]) * 0.15
			for j in range(m):
				if abs(v[j]) > f*mxv:
					tmp.append( (abs(v[j]), IC.index._fmt(rmap[j]), v[j]) )
			tmp.sort()
			for (av, nmj, vj) in tmp[max(0, len(tmp)-15):]:
				print '#\t%.3f * %s' % (vj, nmj)




def process_logs(per_fn, hdr, selector=None, ProbStuff=[], xargs=[]):
	for (fn, nsamp) in compute_pfn_wts(per_fn, selector):
		print '#  samples used =', nsamp, "filename=", fn

	nm, avg, sig = LT.logp_stdev(per_fn, selector)
	if sig is not None:
		print '%.2f +- %.1f %s' % (avg, sig, IC.index._fmt(nm))
	else:
		print '%.2f %s' % (avg, IC.index._fmt(nm))
	for ps in ProbStuff:
		for (nm, pml) in ps(per_fn, hdr, selector, xargs):
			print '%.2f %s' % (pml, IC.index._fmt(nm))

	avglist = LT.indexer_stdev(per_fn, selector)
	print '# n= %d' % len(avglist)
	avglist.sort(lambda a, b: LT.key_cmp(a[0], b[0]))
	for (nm, avg, sig) in avglist:
		if sig is not None:
			# print avg, "+-", sig, IC.index._fmt(nm)
			print '%7g +- %6g %s' % (avg, sig, IC.index._fmt(nm))
		else:
			# print avg, IC.index._fmt(nm)
			print '%7g %s' % (avg, IC.index._fmt(nm))

	print_correlations(per_fn, selector)



def _flag_value(arglist, flag):
	"""Remove and return the value that follows C{flag} on the command line.

	Dies with a readable message, instead of an IndexError traceback, when
	the command line ends before the value.
	"""
	if not arglist:
		die.die('Flag %s requires an argument' % flag)
	return arglist.pop(0)


def _model_evaluator(arglist, flag, fcn, use_sys_path):
	"""Consume 'modname arg' after C{flag} and wrap C{fcn} in a ModelEvaluator."""
	modname = _flag_value(arglist, flag)
	fcn_arg = _flag_value(arglist, flag)
	return ModelEvaluator(load_mod.load_named(modname, use_sys_path), fcn, fcn_arg)


def run(arglist):
	"""Parse the command line, read the log files and produce the output.

	Leading arguments that start with '-' are flags; everything after them
	(or after '--') is a log file name.  C{arglist} is consumed in place.

	Flags:
		- -best, -eachbest, -good, -eachgood, -last, -all: choose which
			samples are used (default: LT.after_convergence).
		- -uid UID: passed on to LT.read_many_files.
		- -xarg ARG: extra argument appended to the model's argument list
			(may be repeated).
		- -fromstart: read the logs with trigger=None instead of LT.TRIGGER.
		- -plot: plot the log-probability traces.
		- -convergence, -draw, -print, -modelcompare MODNAME ARG: load the
			model MODNAME and run the corresponding analysis with ARG.
			The capitalised spellings (-Convergence, -Draw, -Print,
			-ModelCompare) pass False instead of True as the second
			argument of load_mod.load_named (presumably whether to search
			sys.path -- TODO confirm).
	"""
	# Selector flag -> attribute of LT.  Looked up lazily, so a selector only
	# has to exist when its flag is actually used.
	selector_names = {
		'-best': 'overall_best',
		'-eachbest': 'each_best',
		'-good': 'some_after_convergence',
		'-eachgood': 'near_each_max',
		'-last': 'last',
		'-all': 'all',
		}
	Selector = LT.after_convergence
	ProbStuff = []
	Trigger = LT.TRIGGER
	uid = None
	Draw = []
	xargs = []
	# Model flag -> (function, destination list, use_sys_path, import pylab first).
	model_flags = {
		'-convergence': (convergence, Draw, True, True),
		'-Convergence': (convergence, Draw, False, True),
		'-draw': (do_pd_plot, Draw, True, True),
		'-Draw': (do_pd_plot, Draw, False, True),
		# Printing needs no plotting library, so pylab is not imported for it.
		'-print': (do_pd_print, Draw, True, False),
		'-Print': (do_pd_print, Draw, False, False),
		'-modelcompare': (P_bayes, ProbStuff, True, False),
		'-ModelCompare': (P_bayes, ProbStuff, False, False),
		}
	while arglist and arglist[0].startswith('-'):
		arg = arglist.pop(0)
		if arg == '--':
			break
		elif arg in selector_names:
			Selector = getattr(LT, selector_names[arg])
		elif arg == '-uid':
			uid = _flag_value(arglist, arg)
		elif arg == '-xarg':
			xargs.append(_flag_value(arglist, arg))
		elif arg == '-plot':
			Draw.append( ModelEvaluator( None, plot_logp))
		elif arg == '-fromstart':
			Trigger = None
		elif arg in model_flags:
			fcn, dest, use_sys_path, wants_pylab = model_flags[arg]
			if wants_pylab:
				# Imported before the model is loaded and before any
				# log is read, so a missing plotting library fails early.
				import pylab
			dest.append(_model_evaluator(arglist, arg, fcn, use_sys_path))
		else:
			die.die('Unrecognized flag: %s' % arg)

	if len(arglist) == 0:
		die.die("Empty argument list!")
	per_fn, hdr = LT.read_many_files(arglist, uid, Nsamp=1000, tail=0.0, trigger=Trigger)
	# Check for data before trying to estimate convergence on it.
	if len(per_fn)==0:
		die.die("No data has been read in from %s" % arglist)
	LT.estimate_convergence(per_fn, LT.FILE_DROP_FAC)

	for drw in Draw:
		drw(per_fn, hdr, selector=Selector, xargs=xargs)
	# NOTE(review): this used to be the 'else' clause of the loop above.  The
	# loop has no 'break', so the clause always ran; the text summary is
	# therefore printed even when drawing was requested.  Confirm whether
	# "only when nothing was drawn" was intended.
	process_logs(per_fn, hdr, selector=Selector, ProbStuff=ProbStuff, xargs=xargs)


# Script entry point: pass everything after the program name to run().
if __name__ == '__main__':
	run(sys.argv[1:])
