#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import atexit
import os
import re
import signal
import subprocess
import tempfile
import threading
from collections import defaultdict
from itertools import zip_longest
from pathlib import Path
from typing import List, Tuple


def parse_versions(versions):
	"""Parse a comma separated list of dotted versions for argparse.

	Args:
		versions: Text such as ``"3.6,3.7,3.10"``.

	Returns:
		Sorted list of unique integer tuples, e.g. ``[(3, 6), (3, 7), (3, 10)]``.
		Tuples compare numerically, so ``3.10`` sorts after ``3.9``.

	Raises:
		argparse.ArgumentTypeError: If the text is empty or a component is
			not an integer.
	"""
	if not versions:
		raise argparse.ArgumentTypeError("Missing versions attribute")
	# A set drops repeated versions: a duplicate would otherwise start two
	# compilations for the same interpreter and skew the marker ranges that
	# are later computed from list positions.
	parsed_versions = set()
	for version in versions.split(','):
		try:
			parsed_versions.add(tuple(int(num) for num in version.split('.')))
		except ValueError:
			# The int() traceback adds nothing for a command-line user.
			raise argparse.ArgumentTypeError("Bad version format: %s" % version) from None
	return sorted(parsed_versions)


def get_output_and_args(compile_args):
	"""Determine the output file and put a placeholder in the pip-compile args.

	Args:
		compile_args: Sequence whose first item is the input file object (only
			its ``name`` attribute is read here) followed by the raw
			pip-compile arguments.

	Returns:
		Tuple ``(output_filename, args)``. ``output_filename`` is the value of
		the first ``-o`` / ``--output-file`` option, or the input name with its
		extension replaced by ``.txt`` when no usable option is present.
		``args`` is a copy of ``compile_args`` in which the output value is the
		literal ``'__OUTPUT__'`` as its own list item; an output option is
		appended when none was given. The input sequence is not modified.
	"""
	compile_args = list(compile_args)
	output_filename = os.path.splitext(compile_args[0].name)[0] + '.txt'

	# Index 0 is the input file object, so option scanning starts at 1.
	# Iterating over a slice copy keeps the edits below safe.
	for idx, arg in enumerate(compile_args[1:], start=1):
		if not isinstance(arg, str):
			continue
		if arg in ('-o', '--output-file'):
			if idx + 1 < len(compile_args):
				output_filename = compile_args[idx + 1]
				compile_args[idx + 1] = '__OUTPUT__'
				return output_filename, compile_args
			# Dangling flag with no value: drop it and fall back to the default.
			compile_args.pop()
			break
		if arg.startswith('--output-file='):
			output_filename = arg[len('--output-file='):]
			# Split into two items so the placeholder stays a standalone
			# argument that can be matched by equality later on.
			compile_args[idx:idx + 1] = ['--output-file', '__OUTPUT__']
			return output_filename, compile_args

	compile_args.extend(['-o', '__OUTPUT__'])
	return output_filename, compile_args


def check_proc_status(proc):
	"""Raise ``CalledProcessError`` when ``proc`` has a non-zero return code.

	Nothing happens when ``proc.returncode`` is ``0`` or ``None``. The command
	stored in the exception is ``proc.args`` joined with single spaces.
	"""
	exit_code = proc.returncode
	if not exit_code:
		return
	command_line = ' '.join(map(str, proc.args))
	raise subprocess.CalledProcessError(exit_code, command_line)


def ensure_venv_exists(version_name, virtualenvs_dir):
	"""Create the ``py<version_name>`` virtualenv unless its directory exists.

	Args:
		version_name: Dotted version text; it selects the ``python<version_name>``
			executable and the ``py<version_name>`` directory name.
		virtualenvs_dir: Path-like parent directory (must support ``/``).

	Raises:
		subprocess.CalledProcessError: If ``python<version_name> -m venv``
			exits with a non-zero status.
	"""
	target_dir = virtualenvs_dir / f'py{version_name}'
	if target_dir.exists():
		return

	interpreter = f'python{version_name}'
	# The child is placed in its own process group via os.setpgrp.
	venv_proc = subprocess.Popen(
		[interpreter, '-m', 'venv', str(target_dir)],
		preexec_fn=os.setpgrp,
	)
	venv_proc.wait()
	check_proc_status(venv_proc)


class PipCompile(threading.Thread):
	"""Thread that upgrades pip-tools in a virtualenv and runs pip-compile there.

	After the thread has finished, ``requirements`` holds the text read back
	from ``output_path`` (it stays empty if the run was stopped or failed) and
	``exc`` holds the exception raised inside the thread, or ``None``.
	"""

	def __init__(self, virtualenv_dir, compile_args, output_path):
		"""Store the run configuration.

		Args:
			virtualenv_dir: Path of the virtualenv; executables are looked up
				in its ``bin`` sub-directory.
			compile_args: Arguments for pip-compile. Every item equal to
				``'__OUTPUT__'`` is replaced by ``output_path``.
			output_path: Path that pip-compile writes to and that is read
				back once pip-compile succeeded.
		"""
		super().__init__()
		self.virtualenv_dir = virtualenv_dir
		self.compile_args = [output_path if arg == '__OUTPUT__' else arg for arg in compile_args]
		self.requirements = ''
		self.output_path = output_path
		self.exc = None
		# Set by stop(); checked before and after every subprocess.
		self.__stopping = False
		# The subprocess currently running, so stop() can terminate it.
		self.__proc = None

	@property
	def python_binary(self):
		"""Path of the virtualenv's ``python`` executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'python')

	@property
	def pip_binary(self):
		"""Path of the virtualenv's ``pip`` executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'pip')

	@property
	def pip_compile_binary(self):
		"""Path of the virtualenv's ``pip-compile`` executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'pip-compile')

	def _run_command(self, args):
		"""Run one subprocess to completion.

		Returns ``False`` when a stop was requested before or while the
		command ran, ``True`` when it finished successfully. Raises
		``CalledProcessError`` on a non-zero exit status.
		"""
		if self.__stopping:
			return False
		# Own process group, so stop() can kill the command with its children.
		self.__proc = subprocess.Popen(args, preexec_fn=os.setpgrp)
		self.__proc.communicate()
		if self.__stopping:
			# The exit status of a killed command is not an error to report.
			return False
		check_proc_status(self.__proc)
		self.__proc = None
		return True

	def run(self):
		"""Upgrade pip-tools, run pip-compile and read the produced output.

		Any exception is stored in ``exc`` instead of propagating, so the
		thread that joins this one can re-raise it.
		"""
		try:
			if not self._run_command([self.pip_binary, 'install', '--upgrade', 'pip-tools']):
				return
			if not self._run_command([self.pip_compile_binary] + self.compile_args):
				return
			with self.output_path.open('r') as fp:
				self.requirements = fp.read()
		except BaseException as e:
			self.exc = e

	def stop(self):
		"""Request the run to stop and kill the subprocess that is running.

		The process gets SIGTERM first and SIGKILL if it has not exited after
		0.1 seconds; finally its whole process group receives SIGKILL.
		Failures to look up or signal the group are ignored, because the
		process may already be gone.
		"""
		self.__stopping = True
		proc = self.__proc
		if proc is None:
			return
		# Look the group up before terminating: afterwards the pid may no
		# longer resolve.
		try:
			group = os.getpgid(proc.pid)
		except OSError:
			group = None
		proc.terminate()
		try:
			proc.wait(0.1)
		except subprocess.TimeoutExpired:
			proc.kill()
		if group is not None:
			try:
				os.killpg(group, signal.SIGKILL)
			except OSError:
				pass


# Splits one physical line into (leading whitespace, requirement text,
# trailing whitespace plus optional comment). A '#' only starts a comment at
# the beginning of the line or after whitespace, so URL fragments such as
# '#egg=name' stay part of the requirement.
_REQUIREMENT_LINE_RE = re.compile(r'(\s*)(.*?)(\s*(?:(?<!\S)#.*)?)\Z', re.DOTALL)


def split_requirements_and_comments(requirements_txt: str) -> List[Tuple[str, str]]:
	"""Split requirements text into ``(requirement, trailing text)`` pairs.

	Each line with content yields one pair: the requirement is the line
	without surrounding whitespace and without its trailing comment, so
	environment markers such as ``pkg==1 ; python_version < "3.8"`` stay in
	one piece. The second item is everything between this requirement and the
	next one: the trailing comment, the line break, following blank or
	comment-only lines and the indentation of the next requirement line.

	Text before the first requirement (for example a header comment) is
	dropped. Continuation lines ending in a backslash are not joined; every
	physical line is treated on its own.

	Args:
		requirements_txt: Content of a requirements file.

	Returns:
		List of ``(requirement, trailing_text)`` tuples in file order.
	"""
	entries = []  # pairs of (requirement, list of trailing text parts)
	for line in requirements_txt.splitlines(keepends=True):
		leading, requirement, trailing = _REQUIREMENT_LINE_RE.match(line).groups()
		if requirement:
			if entries:
				# Indentation belongs to the gap after the previous requirement.
				entries[-1][1].append(leading)
			entries.append((requirement, [trailing]))
		elif entries:
			# Blank or comment-only line: part of the previous entry's text.
			entries[-1][1].append(line)
	return [(requirement, ''.join(parts)) for requirement, parts in entries]


def versions_to_markers(versions, all_versions):
	"""Build a ``python_version`` marker expression covering ``versions``.

	``versions`` is mapped to positions in ``all_versions`` and neighbouring
	positions are grouped into runs. Each run becomes
	``python_version >= "<first>" and python_version < "<next after last>"``;
	a bound is left out when the run touches the start or the end of
	``all_versions``. Several runs are wrapped in parentheses and joined with
	``or``.

	The grouping walks ``versions`` sequentially, so it presumably expects
	them in the same order as ``all_versions``.

	Args:
		versions: Non-empty list of version names taken from ``all_versions``.
		all_versions: Every version name that was compiled, in order.

	Returns:
		The marker text; an empty string when no bound is needed.
	"""
	positions = [all_versions.index(name) for name in versions]

	# Collapse the positions into runs of neighbouring indexes.
	runs = []
	run_first = run_last = positions[0]
	for position in positions:
		if position - run_last > 1:
			runs.append((run_first, run_last))
			run_first = position
		run_last = position
	runs.append((run_first, run_last))

	last_position = len(all_versions) - 1
	clauses = []
	for run_first, run_last in runs:
		bounds = []
		if run_first != 0:
			bounds.append(f'python_version >= "{all_versions[run_first]}"')
		if run_last != last_position:
			bounds.append(f'python_version < "{all_versions[run_last + 1]}"')
		# A run spanning every version needs no marker at all.
		if bounds:
			clauses.append(' and '.join(bounds))

	if len(clauses) > 1:
		clauses = [f'({clause})' for clause in clauses]
	return ' or '.join(clauses)


def pip_compile(versions, compile_args, virtualenvs_dir):
	"""Run pip-compile once per Python version and merge the results.

	For every version the input file is copied to a temporary directory, a
	virtualenv is created when missing, and pip-compile runs in a worker
	thread. The per-version results are merged into one output file, where a
	requirement that was not produced for every version gets a
	``python_version`` marker.

	Args:
		versions: List of version tuples, e.g. ``[(3, 6), (3, 7)]``.
		compile_args: List whose first item is the opened input file and
			whose remaining items are forwarded to pip-compile.
		virtualenvs_dir: Path of the directory holding the ``py<version>``
			virtualenvs.

	Raises:
		SystemExit: When SIGTERM, SIGINT or SIGHUP arrives; the exit code is
			``128 + signal number``. The output file is not rewritten then.
		BaseException: Whatever exception a worker thread recorded.

	Side effects: registers an ``atexit`` callback and replaces the handlers
	for SIGTERM, SIGINT and SIGHUP; they are not restored afterwards.
	"""
	version_names = [".".join(str(v) for v in version) for version in versions]
	output_filename, compile_args = get_output_and_args(compile_args)

	# The input is read once here and copied per version further down.
	first_input = compile_args[0]
	first_input.seek(0)
	first_input_data = first_input.read()
	first_input_replacement = first_input.name
	first_input.close()

	# 'a+' creates the output file when it does not exist yet; its current
	# content seeds every per-version output file.
	with open(output_filename, 'a+') as fp:
		fp.seek(0)
		current_saved_requirements = fp.read()

	processes = []

	def at_exit(*args):
		for proc in processes:
			proc.stop()

	def on_signal(signum, frame):
		# Stopping the workers alone is not enough: without leaving here the
		# code below would merge the empty results of the stopped workers and
		# overwrite the output file with them.
		at_exit()
		raise SystemExit(128 + signum)

	# Installed before any worker starts, so an early signal is handled too.
	atexit.register(at_exit)
	for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP):
		signal.signal(sig, on_signal)

	first_inputs = []
	requirements = []
	with tempfile.TemporaryDirectory() as tmp_dir:
		tmp_dir_path = Path(tmp_dir)
		try:
			for version_name in version_names:
				input_path = tmp_dir_path / f'py{version_name}.in'
				with input_path.open('w') as fp:
					fp.write(first_input_data)
				output_path = tmp_dir_path / f'py{version_name}.txt'
				with output_path.open('w') as fp:
					fp.write(current_saved_requirements)

				version_compile_args = compile_args[:]  # copy
				version_compile_args[0] = input_path
				first_inputs.append(str(input_path))

				ensure_venv_exists(version_name, virtualenvs_dir)
				compile_proc = PipCompile(virtualenvs_dir / f'py{version_name}', version_compile_args, output_path)
				# Registered before start() so that a signal arriving in
				# between still reaches this worker.
				processes.append(compile_proc)
				compile_proc.start()

			for proc in processes:
				proc.join()
				if proc.exc:
					raise proc.exc
				requirements.append(proc.requirements)
		finally:
			# No-op after a clean run. After a failure this stops the
			# remaining workers and gives them a moment to finish before the
			# temporary directory is removed.
			at_exit()
			for proc in processes:
				if proc.is_alive():
					proc.join(1)

	version_specs = defaultdict(lambda: {'versions': [], 'markers': '', 'comment': ''})

	for version_name, first_tmp_input, version_requirements in zip(version_names, first_inputs, requirements):
		for req, comment in split_requirements_and_comments(version_requirements):
			# Markers already present on the requirement are kept separately
			# and combined with the version markers when writing.
			req, _, markers = req.partition(';')
			req = req.strip()
			version_specs[req]['markers'] = markers.strip()
			version_specs[req]['versions'].append(version_name)
			# Comments mention the temporary input path; show the real one.
			version_specs[req]['comment'] = comment.replace(first_tmp_input, first_input_replacement)

	with open(output_filename, 'w') as fp:
		for spec, item in version_specs.items():
			fp.write(spec)
			markers = versions_to_markers(item['versions'], version_names)
			if item['markers'] and item['markers'] != markers:
				if markers:
					markers = f'({item["markers"]}) and ({markers})'
				else:
					markers = item["markers"]
			if markers:
				fp.write(f' ; {markers}')
			fp.write(item['comment'])


def main(argv=None):
	"""Parse the command line and run the multi-version compilation.

	Args:
		argv: Argument list to parse. ``None`` (the default) lets argparse
			read ``sys.argv[1:]``, which is what the script entry point uses.
	"""
	parser = argparse.ArgumentParser(description="Compile requirements.in to single requirements.txt file with multiple python versions support")
	parser.add_argument('versions', type=parse_versions, help="Comma separated list of python versions")
	parser.add_argument('--virtualenvs_dir', default=str(Path.home() / '.virtualenvs'), help="Where to create virtualenvs")
	parser.add_argument('input', type=argparse.FileType(mode='r'), help="First input file")
	# Everything after the input file is passed through to pip-compile.
	parser.add_argument('pip_compile_args', nargs=argparse.REMAINDER, help="Arguments forwarded to pip-tools")
	args = parser.parse_args(argv)
	compile_args = [args.input] + args.pip_compile_args
	pip_compile(args.versions, compile_args, Path(args.virtualenvs_dir))


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
	main()
