#!/usr/bin/python3
# -*- coding: utf-8 -*-


import os, sys, time
import argparse
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tkinter import *


def get_parser() :
	# Fonction permettant de pourvoir demander des arguments

	parser = argparse.ArgumentParser(description= \
		"run GENIAL for several databases", \
			formatter_class=argparse.ArgumentDefaultsHelpFormatter)

	parser.add_argument("-f", action="store", dest="input", 
					type=str, required=True, help="tsv file with FASTA files paths and strains IDs (REQUIRED)")

	parser.add_argument("--db", action="store", dest="databases", \
						type=str, required=True, nargs="+", choices=["resfinder", "card",	"argannot", "ecoh", \
							"ecoli_vf", "plasmidfinder", "vfdb", "ncbi", "enterotox_staph", "phages"], help="default \
								database list (resfinder, card, argannot, ecoh, ecoli_vf, plasmidfinder, vfdb)")

	parser.add_argument("-T", action="store", dest="nbThreads", 
					type=str, default="1", help="number of theard to use")

	parser.add_argument("-w", action="store", dest="workdir", 
		type=str, default=".", help="working directory")

	parser.add_argument("-r", action="store", dest="resdir", 
					type=str, default="ABRicate_results", help="results directory name")

	parser.add_argument("--mincov", action="store", dest="mincov", \
						type=str, default="80", help="Minimum proportion of gene covered")

	parser.add_argument("--minid", action="store", dest="minid", \
						type=str, default="90", help="Minimum proportion of exact nucleotide matches")

	parser.add_argument("--R", action="store_true", dest="remove",
					default=False, help="remove genes present in all genomes from the matrix (default:False)")

	return parser


def heatmap(matrix, heatmapName, dirHeatmap, nbGenes) :
	
	#if nbGenes > nbGenomes :
	try : 
		heatmap = sns.clustermap(matrix, metric="euclidean", method="average", figsize=(round(0.23*nbGenes + 0.66), round(0.23*nbGenes + 0.66)), linewidths=.003, col_cluster = False, yticklabels=True, cmap = "Greys")  # construction de la heatmap
		heatmap.savefig(dirHeatmap + heatmapName) # sauvegarde de la heatmap dans un png

	except RecursionError : 
		print("Erreur lors de la construction de la heatmap, la matrice est trop grande") # affichage d'un message d'erreur si la matrice est trop grande

	except TclError :
		print("Erreur : La construction de la heatmap nécéssite trop de mémoire") # affichage d'un message d'erreur si la matrice est trop grande

	# else :
	# 	try : 
	# 		heatmap = sns.clustermap(matrix, metric="euclidean", method="average", figsize=(round(0.23*nbGenomes + 0.66), round(0.23*nbGenomes + 0.66)), linewidths=.003, col_cluster = False, yticklabels=True, cmap = "Greys")  # construction de la heatmap
	# 		heatmap.savefig(dirHeatmap + heatmapName) # sauvegarde de la heatmap dans un png

	# 	except RecursionError : 
	# 		print("Erreur lors de la construction de la heatmap, la matrice est trop grande") # affichage d'un message d'erreur si la matrice est trop grande

	# 	except TclError :
	# 		print("Erreur : La construction de la heatmap nécéssite trop de mémoire") # affichage d'un message d'erreur si la matrice est trop grande


def main():
	
	##################### gets arguments #####################

	parser=get_parser()
	
	#print parser.help if no arguments
	if len(sys.argv)==1:
		parser.print_help()
		sys.exit(1)
	
	# mettre tout les arguments dans la variable Argument
	Arguments=parser.parse_args()

	begin = time.time()


	if Arguments.workdir[-1] != "/" :
		Arguments.workdir += "/"

	if Arguments.resdir[-1] != "/" : 
		Arguments.resdir += "/"

	WORKDIR = Arguments.workdir
	RESDIR = Arguments.resdir

	if not os.path.exists(WORKDIR + RESDIR) :
		os.system("mkdir " + WORKDIR + RESDIR)
	
	allMatrix = []

	for db in Arguments.databases :

		newGenesNames = []

		os.system("python GENIALanalysis -f " +  Arguments.input + " -T " + Arguments.nbThreads + " --defaultdb " + db + " -w " + WORKDIR + RESDIR + " -r " + db + " --minid " + Arguments.minid + " --mincov " + Arguments.mincov)

		if os.path.isfile(WORKDIR + RESDIR + db + "/ABRicate_files.tsv") :  # Vérifie que le fichier abricate_files.tsv existe
			if Arguments.remove :
				os.system("python GENIALresults -f " + WORKDIR + RESDIR + db + "/ABRicate_files.tsv " + " --defaultdb " + db + " -w " + WORKDIR + RESDIR + " -r " + db + " --R")

			else :
				os.system("python GENIALresults -f " + WORKDIR + RESDIR + db + "/ABRicate_files.tsv " + " --defaultdb " + db + " -w " + WORKDIR + RESDIR + " -r " + db)
		
		matrix  = pd.read_csv(WORKDIR + RESDIR + db + "/Matricies/matrixAllGenes.tsv", sep = "\t", index_col = 0, dtype = str)

		genes = matrix.columns # liste des gènes

		for gene in genes : # pour chaque gène

			newGeneName = db + " | " + gene # association du gène à sa base de données
			newGenesNames.append(newGeneName)

		matrix.columns = newGenesNames # renomages des colonnes de la matrice

		allMatrix.append(matrix) # liste de toutes les matrices  à concaténer

	matrix = pd.concat(allMatrix, axis=1)	# concaténation des matrices

	matrix.to_csv(WORKDIR + RESDIR + "matrixAllDatabases.tsv", sep="\t", index = True) # écriture de la matrice dans un fichier tsv

	matrix = matrix[matrix.columns].astype(int)

	heatmap(matrix, "heatmapAllDatabases.png", WORKDIR + RESDIR, len(genes))

	end = time.time()

	print ("Temps d'exécution total : " + str(round(end - begin,3)) + " secondes")

# lancer la fonction main()  au lancement du script
if __name__ == "__main__":
	main()	