#!/usr/bin/env python

import sys
import os
import argparse
import logging
import warnings

import loompy
from loom_viewer import start_server, LoomTiles, LoomExpand

import numpy as np

import timeit
import random


class VerboseArgParser(argparse.ArgumentParser):
	"""ArgumentParser variant whose error handler shows the full help text.

	On any parse error the complete help is printed first, followed by the
	error message on stderr, and the process exits with status 2.
	"""

	def error(self, message):
		"""Print help, write ``message`` to stderr, then exit with status 2."""
		self.print_help()
		error_text = '\nerror: %s\n' % message
		sys.stderr.write(error_text)
		raise SystemExit(2)

def connect_loom(file_path):
	"""Open a Loom file if it exists.

	Args:
		file_path: path to a ``.loom`` file.

	Returns:
		The object returned by ``loompy.connect(file_path)``, or ``None`` when
		``file_path`` does not exist (a warning is logged in that case).
	"""
	if not os.path.exists(file_path):
		# logging.warn is a deprecated alias; logging.warning is the real API.
		logging.warning("  Could not find %s" % file_path)
		return None
	logging.info("  Connecting to %s" % file_path)
	return loompy.connect(file_path)

def list_filename_matches(dataset_path, filename):
	"""Find every project folder in ``dataset_path`` that contains ``filename``.

	Project folders are the non-hidden (not starting with ".") subdirectories
	of ``dataset_path``.

	Args:
		dataset_path: root datasets directory.
		filename: file name to look for inside each project folder.

	Returns:
		A list of ``(project, filename, file_path)`` tuples, one per project
		whose folder contains a regular file named ``filename``.
	"""
	logging.info('Looking for Loom files matching %s' % filename)
	projects = [
		entry for entry in os.listdir(dataset_path)
		if not entry.startswith(".") and os.path.isdir(os.path.join(dataset_path, entry))
	]
	logging.info('Found %i projects' % len(projects))
	matching_files = []
	for project in projects:
		project_path = os.path.join(dataset_path, project)
		logging.info('Entering project %s' % project_path)
		# Projects were already filtered with isdir(), so neither an exists()
		# check nor a full directory listing is needed: one isfile() suffices.
		file_path = os.path.join(project_path, filename)
		if os.path.isfile(file_path):
			logging.info('  Found a matching Loom file')
			matching_files.append((project, filename, file_path))
	return matching_files

def list_project_files(dataset_path, project):
	"""List the ``.loom`` files directly inside one project folder.

	Args:
		dataset_path: root datasets directory.
		project: name of a project folder inside ``dataset_path``.

	Returns:
		A list of ``(project, filename, file_path)`` tuples, one per entry
		whose name ends in ".loom". If ``dataset_path/project`` is not a
		directory, a ``UserWarning`` is issued and an empty list is returned.
	"""
	logging.info('Listing all Loom files in %s' % project)
	project_path = os.path.join(dataset_path, project)
	if not os.path.isdir(project_path):
		# warnings.warn() returns None, so the former `raise warnings.warn(...)`
		# crashed callers with "exceptions must derive from BaseException".
		warnings.warn("%s is not a path to a Project folder!" % project_path)
		return []

	loom_files = [
		(project, filename, os.path.join(project_path, filename))
		for filename in os.listdir(project_path)
		if filename.endswith(".loom")
	]

	# Compare counts by value (==); `is` on ints relies on interpreter caching.
	total_loom_files = len(loom_files)
	if total_loom_files == 0:
		logging.info("No loom files found in %s folder" % project)
	elif total_loom_files == 1:
		logging.info("Found 1 loom file")
	else:
		logging.info("Found %i loom files" % total_loom_files)

	return loom_files

def tile_command(dataset_path, filename, truncate=False):
	"""Precompute heatmap tiles for every project file named ``filename``.

	Args:
		dataset_path: root datasets directory.
		filename: Loom file name to look for in each project folder.
		truncate: passed to ``LoomTiles.prepare_heatmap``. Defaults to False
			so the CLI's two-argument call keeps working.

	Files that cannot be opened or tiled are logged and skipped; the remaining
	matches are still processed.
	"""
	loom_files = list_filename_matches(dataset_path, filename)
	for _, _, file_path in loom_files:
		try:
			ds = connect_loom(file_path)
		except Exception as e:
			logging.error(e)
			continue
		if ds is None:
			# The old `raise warnings.warn(...)` raised a TypeError whose text
			# replaced this message in the log; report the real problem.
			logging.error("Could not connect to %s" % file_path)
			continue
		try:
			logging.info("    Precomputing heatmap tiles, stored in subfolder:\n    %s.tiles" % file_path)
			tiles = LoomTiles(ds)
			tiles.prepare_heatmap(truncate)
		except Exception as e:
			logging.error(e)

def tile_project_command(dataset_path, project, truncate=False):
	"""Precompute heatmap tiles for every ``.loom`` file in one project folder.

	Args:
		dataset_path: root datasets directory.
		project: name of the project folder inside ``dataset_path``.
		truncate: passed to ``LoomTiles.prepare_heatmap``. Defaults to False
			so the CLI's two-argument call keeps working.

	Files that cannot be opened or tiled are logged and skipped; the remaining
	files are still processed.
	"""
	loom_files = list_project_files(dataset_path, project)
	for _, filename, file_path in loom_files:
		try:
			ds = connect_loom(file_path)
		except Exception as e:
			logging.error(e)
			continue
		if ds is None:
			logging.error("Could not connect to %s" % file_path)
			continue
		try:
			# Both values must be supplied as one tuple: the old
			# `fmt % filename, file_path` raised TypeError ("not enough
			# arguments for format string"), so no tiles were ever computed.
			logging.info("    Precomputing %s heatmap tiles, stored in subfolder:\n    %s.tiles" % (filename, file_path))
			tiles = LoomTiles(ds)
			tiles.prepare_heatmap(truncate)
		except Exception as e:
			logging.error(e)

def expand_command(dataset_path, filename, truncate, metadata, attributes, rows, cols):
	"""Run ``LoomExpand`` on every project file named ``filename``.

	Args:
		dataset_path: root datasets directory.
		filename: Loom file name to look for in each project folder.
		truncate: forwarded to each selected ``LoomExpand`` method.
		metadata, attributes, rows, cols: select which of ``LoomExpand``'s
			``metadata``, ``attributes``, ``rows`` and ``columns`` methods run.

	Does nothing (after logging a message) when no selector is set.
	Per-file failures are logged and the next file is processed.
	"""
	if not (metadata or attributes or rows or cols):
		logging.info('Must explicitly state what to expand!')
		return
	loom_files = list_filename_matches(dataset_path, filename)
	for project, _, file_path in loom_files:
		try:
			ds = connect_loom(file_path)
			if ds is None:
				# The old `raise warnings.warn(...)` raised a TypeError whose
				# text replaced this message in the log.
				logging.error("Could not connect to %s" % file_path)
				continue
			expand = LoomExpand(ds, dataset_path, project, filename, file_path)
			if metadata:
				expand.metadata(truncate)
			if attributes:
				expand.attributes(truncate)
			if rows:
				expand.rows(truncate)
			if cols:
				expand.columns(truncate)
		except Exception as e:
			logging.error(e)

def expand_project_command(dataset_path, project, truncate, metadata, attributes, rows, cols):
	"""Run ``LoomExpand`` on every ``.loom`` file in one project folder.

	Args:
		dataset_path: root datasets directory.
		project: name of the project folder inside ``dataset_path``.
		truncate: forwarded to each selected ``LoomExpand`` method.
		metadata, attributes, rows, cols: select which of ``LoomExpand``'s
			``metadata``, ``attributes``, ``rows`` and ``columns`` methods run.

	Does nothing (after logging a message) when no selector is set, matching
	``expand_command`` and ``expand_all_command``. Per-file failures are
	logged and the next file is processed.
	"""
	if not (metadata or attributes or rows or cols):
		logging.info('Must explicitly state what to expand!')
		return
	loom_files = list_project_files(dataset_path, project)
	for _, filename, file_path in loom_files:
		try:
			ds = connect_loom(file_path)
			if ds is None:
				# The old `raise warnings.warn(...)` raised a TypeError whose
				# text replaced this message in the log.
				logging.error("Could not connect to %s" % file_path)
				continue
			expand = LoomExpand(ds, dataset_path, project, filename, file_path)
			if metadata:
				expand.metadata(truncate)
			if attributes:
				expand.attributes(truncate)
			if rows:
				expand.rows(truncate)
			if cols:
				expand.columns(truncate)
		except Exception as e:
			logging.error(e)

def expand_all_command(dataset_path, truncate, metadata, attributes, rows, cols):
	"""Run ``LoomExpand`` on every ``.loom`` file of every project folder.

	Project folders are the non-hidden subdirectories of ``dataset_path``.

	Args:
		dataset_path: root datasets directory.
		truncate: forwarded to each selected ``LoomExpand`` method.
		metadata, attributes, rows, cols: select which of ``LoomExpand``'s
			``metadata``, ``attributes``, ``rows`` and ``columns`` methods run.

	Does nothing (after logging a message) when no selector is set.
	Per-file failures are logged and the next file is processed.
	"""
	if not (metadata or attributes or rows or cols):
		logging.info('Must explicitly state what to expand!')
		return
	logging.info('Searching for projects in %s' % dataset_path)
	projects = [x for x in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, x)) and not x.startswith(".")]
	logging.info('  Found %i projects' % len(projects))
	for project in projects:
		logging.info('Entering project %s' % project)
		loom_files = list_project_files(dataset_path, project)
		for _, filename, file_path in loom_files:
			try:
				ds = connect_loom(file_path)
				if ds is None:
					# The old `raise warnings.warn(...)` raised a TypeError whose
					# text replaced this message in the log.
					logging.error("Could not connect to %s" % file_path)
					continue
				expand = LoomExpand(ds, dataset_path, project, filename, file_path)
				if metadata:
					expand.metadata(truncate)
				if attributes:
					expand.attributes(truncate)
				if rows:
					expand.rows(truncate)
				if cols:
					expand.columns(truncate)
			except Exception as e:
				logging.error(e)

def fromloom_command(dataset_path, infile, outfile, project, dtype, _chunks, chunk_cache, compression_opts, shuffle, fletcher32):
	"""Copy a Loom file into a new one, optionally changing HDF5 settings.

	Args:
		dataset_path: root datasets directory.
		infile: path of the source Loom file.
		outfile: output file name; joined onto ``dataset_path/project`` when
			``project`` is given, used as-is otherwise.
		project: optional project folder name.
		dtype: matrix data type; when None, the source matrix's dtype is used.
		_chunks: edge length of the square ``(_chunks, _chunks)`` chunk shape.
			None or a value below 1 falls back to 64.
		chunk_cache, compression_opts, shuffle, fletcher32: forwarded
			unchanged to ``loompy.create_from_loom``.
	"""
	if project is not None:
		outfile = os.path.join(dataset_path, project, outfile)

	if dtype is None:
		ds = loompy.connect(infile, 'r')
		try:
			# ds.file['matrix'] is an HDF5 dataset: its element type is exposed
			# as `.dtype` (the previous `.type` lookup is not an attribute of it).
			dtype = ds.file['matrix'].dtype
		finally:
			# Release the read handle before the copy reopens the same file.
			ds.close()

	# The old check inspected a tuple that had just been built, so it could
	# never fire; validate the scalar edge length instead.
	if _chunks is None or _chunks < 1:
		_chunks = 64
	chunks = (_chunks, _chunks)

	logging.info("Converting %s to %s" % (infile, outfile))
	logging.info("dtype: %s, chunks size: %s, chunk cache: %s, compression: %s, shuffle: %s, fletcher32: %s" % (dtype, chunks, chunk_cache, compression_opts, shuffle, fletcher32))
	loompy.create_from_loom(infile, outfile, chunks, chunk_cache, dtype, compression_opts, shuffle, fletcher32)

def benchmark_command(infile):
	"""Time three row-access patterns on a Loom file and log the results.

	Each benchmark statement is executed 10 times with ``timeit``. The logged
	number is the total elapsed seconds returned by ``Timer.timeit(10)``.
	Garbage collection is re-enabled in each setup, so it is included in the timing.

	Args:
		infile: path of the Loom file to benchmark.
	"""
	# %r embeds the path as a properly quoted Python literal, so paths that
	# contain quotes or backslashes (e.g. Windows paths) no longer corrupt the
	# generated setup code.
	connect = 'gc.enable(); import loompy; import random; ds = loompy.connect(%r, "r")' % infile
	benchmarks = [
		("Benchmarking %s random row access, 10 x 100",
			connect + '; rmax = ds.file["matrix"].shape[0]-1',
			'for i in range(0,100): ds[random.randint(0, rmax),:]'),
		("Benchmarking %s loading 100 rows at once, 10 x",
			connect,
			'ds[0:100,:]'),
		("Benchmarking %s sequential row access, 10 x 100",
			connect,
			'for i in range(0,100): ds[i,:]'),
	]
	for description, setup, statement in benchmarks:
		logging.info(description % infile)
		logging.info(timeit.Timer(statement, setup).timeit(10))

class Empty(object):
	"""Bare attribute container; the no-argument CLI path uses it as an args object."""

def _add_expand_flags(subparser):
	"""Add the -t/-m/-a/-r/-c flags shared by the expand subcommands."""
	subparser.add_argument("-t", "--truncate", help="Remove previously expanded files if present (False by default)", action='store_true')
	subparser.add_argument("-m", "--metadata", help="Expand metadata (False by default)", action='store_true')
	subparser.add_argument("-a", "--attributes", help="Expand attributes (False by default)", action='store_true')
	subparser.add_argument("-r", "--rows", help="Expand rows (False by default)", action='store_true')
	subparser.add_argument("-c", "--cols", help="Expand columns (False by default)", action='store_true')


def _build_parser(def_dir):
	"""Build the ``loom`` argument parser with all of its subcommands.

	Args:
		def_dir: default for ``--dataset-path``.
	"""
	parser = VerboseArgParser(description='Loom command-line tool.')
	parser.add_argument('--debug', action="store_true")
	parser.add_argument('--dataset-path', help="Path to datasets directory (default: %s)" % def_dir , default=def_dir)

	subparsers = parser.add_subparsers(title="subcommands", dest="command")

	# loom version
	subparsers.add_parser('version', help="Print version")

	# loom server
	server_parser = subparsers.add_parser('server', help="Launch loom server (default command)")
	server_parser.add_argument('--show-browser', help="Automatically launch browser", action="store_true")
	server_parser.add_argument('-p','--port', help="Port", type=int, default=80)

	# loom tile
	tile_parser = subparsers.add_parser('tile', help="Precompute heatmap tiles")
	tile_parser.add_argument("file", help="Loom input file")
	tile_parser.add_argument("-t", "--truncate", help="Remove previously generated tiles if present (False by default)", action='store_true')

	# loom tile all within project
	tile_project_parser = subparsers.add_parser('tile-project', help="Precompute heatmap tiles for all loom files in a project")
	tile_project_parser.add_argument("project", help="Project directory name")
	tile_project_parser.add_argument("-t", "--truncate", help="Remove previously generated tiles if present (False by default)", action='store_true')

	# loom expand
	expand_parser = subparsers.add_parser('expand', help="Expand a loom file in the datasets folder to compressed pickle/json files for better server performance. Automatically searches through projects and expands all files with a matching name.")
	expand_parser.add_argument("file", help="Loom input file")
	_add_expand_flags(expand_parser)

	# loom expand all loom files in a project
	expand_project_parser = subparsers.add_parser('expand-project', help="Expand all loom files in given project of the datasets folder to compressed pickle/json files for better server performance.")
	expand_project_parser.add_argument("project", help="Project directory name")
	_add_expand_flags(expand_project_parser)

	# loom expand all datasets
	expand_all_parser = subparsers.add_parser('expand-all', help="Expand all loom files in the data folder for better server performance")
	_add_expand_flags(expand_all_parser)

	# loom from-loom
	loom_parser = subparsers.add_parser('from-loom', help="Create a loom file from another loom file, letting you change HDF5 settings in the process. Useful to test the effect of various settings on performance")
	loom_parser.add_argument('-o','--outfile', help="Name of output file", required=True)
	loom_parser.add_argument('-i','--infile', help="Name of input loom file", required=True)
	loom_parser.add_argument('--project', help="Project name")
	loom_parser.add_argument('--dtype', help='Matrix data type. Defaults to "float32"', default="float32")
	loom_parser.add_argument('--chunks', help='Chunk tile size, i.e. "--chunks 10" (defaults to 64)', type=int, default=64)
	loom_parser.add_argument('--chunk_cache', help="Chunk cache size in MB (defaults to 512)", type=int, default=512)
	loom_parser.add_argument('--compression_opts', help='Gzip compression strength. Default: 4', type=int, default=4)
	loom_parser.add_argument('--shuffle', help='Use shuffle filter on chunks, defaults to False', action='store_true')
	loom_parser.add_argument('--fletcher32', help='Use fletcher32 checksum on chunks, defaults to False', action='store_true')

	# loom benchmark
	benchmark_parser = subparsers.add_parser('benchmark', help="Benchmark random row access. Can be useful when testing the effects of  various chunk tile size and cache size settings")
	benchmark_parser.add_argument('-i','--infile', help="Name of input loom file", required=True)

	return parser


def main():
	"""Entry point of the ``loom`` command-line tool."""
	# An unset or empty LOOM_PATH falls back to ~/loom-datasets.
	def_dir = os.environ.get('LOOM_PATH') or os.path.join(os.path.expanduser("~"), "loom-datasets")

	if len(sys.argv) == 1:
		# No arguments at all: build a minimal args object by hand.
		args = Empty()
		args.debug = False
		args.dataset_path = def_dir
		args.command = None
	else:
		args = _build_parser(def_dir).parse_args()

	if args.command is None:
		# No subcommand (e.g. only --debug): launch the server with the same
		# local defaults the no-argument case has always used.
		args.command = "server"
		args.port = 8003
		args.show_browser = True

	logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

	# Handled before the dataset directory check so `loom version` has no
	# filesystem side effects.
	if args.command == "version":
		print("loom v" + str(loompy.__version__))
		sys.exit(0)

	if not os.path.exists(args.dataset_path):
		logging.info("Creating dataset directory: " + args.dataset_path)
		# makedirs also creates missing parent directories.
		os.makedirs(args.dataset_path)

	if args.command == "server":
		start_server(args.dataset_path, args.show_browser, args.port, args.debug)
	elif args.command == "tile":
		# truncate was previously omitted, so this call raised TypeError.
		tile_command(args.dataset_path, args.file, args.truncate)
	elif args.command == "tile-project":
		tile_project_command(args.dataset_path, args.project, args.truncate)
	elif args.command == "expand":
		expand_command(args.dataset_path, args.file, args.truncate, args.metadata, args.attributes, args.rows, args.cols)
	elif args.command == "expand-project":
		expand_project_command(args.dataset_path, args.project, args.truncate, args.metadata, args.attributes, args.rows, args.cols)
	elif args.command == "expand-all":
		expand_all_command(args.dataset_path, args.truncate, args.metadata, args.attributes, args.rows, args.cols)
	elif args.command == "from-loom":
		fromloom_command(args.dataset_path, args.infile, args.outfile, args.project, args.dtype, args.chunks, args.chunk_cache, args.compression_opts, args.shuffle, args.fletcher32)
	elif args.command == "benchmark":
		benchmark_command(args.infile)


if __name__ == '__main__':
	main()