#!python

import io
import re
from os import environ
from os.path import join, isfile, realpath, expanduser
from gnureadline import parse_and_bind, set_completer, read_history_file, write_history_file
from argparse import ArgumentParser
from contextlib import contextmanager
from signal import signal, SIGINT, SIGALRM, setitimer, ITIMER_REAL
from textwrap import dedent
from glob import glob
from functools import partial

from lark import Lark
from lark.visitors import Interpreter

from gpt4all_cli import *

REVISION:str|None
try:
  REVISION=environ["GPT4ALLCLI_REVISION"]
except Exception:
  try:
    from subprocess import check_output, DEVNULL
    REVISION=check_output(['git', 'rev-parse', 'HEAD'],
                          cwd=environ['GPT4ALLCLI_ROOT'],
                          stderr=DEVNULL).decode().strip()
  except Exception:
    try:
      from gpt4all_cli.revision import REVISION as __rv__
      REVISION = __rv__
    except ImportError:
      REVISION = None

def _get_cmd_prefix():
  """ Return the first element shared by every non-empty entry of COMMANDS.

  Asserts that all non-empty entries start with the same element. """
  leading = set()
  for name in COMMANDS:
    if len(name) > 0:
      leading.add(name[0])
  assert len(leading) == 1
  (prefix,) = leading
  return prefix

# Leading element common to all entries of COMMANDS, computed once at import.
CMDPREFIX = _get_cmd_prefix()

# Default name of the readline history file.
HISTORY_DEF="_gpt4all_cli_history"

# Command-line interface of the program. Several of these values are changed
# later on the parsed namespace by the interactive commands handled in `main`.
ARG_PARSER = ArgumentParser(description="Command-line arguments")

# --- Model selection and parameters ---
ARG_PARSER.add_argument(
  "--model-dir", type=str, default=None,
  help="Model directory to prepend to model file names")
# No model is selected by default.
ARG_PARSER.add_argument(
  "--model", "-m", type=str, default=None, metavar="[STR1:]STR2",
  help="Model to use. STR1 is 'gpt4all' (the default) or 'openai'. STR2 is the model name")
ARG_PARSER.add_argument(
  "--num-threads", "-t", type=int, default=None,
  help="Number of threads to use")
ARG_PARSER.add_argument(
  '--model-apikey', type=str, default=None, metavar="STR",
  help="Model provider-specific API key")
ARG_PARSER.add_argument(
  "--model-temperature", type=float, default=None,
  help="Temperature parameter of the model")
ARG_PARSER.add_argument(
  "--device", "-d", type=str, default=None,
  help="Device to use for chatbot, e.g. gpu, amd, nvidia, intel. Defaults to CPU")

# --- Readline behaviour ---
ARG_PARSER.add_argument(
  "--readline-key-send", type=str, default="\\C-k",
  help="Terminal code to treat as Ctrl+Enter (default: \\C-k)")
ARG_PARSER.add_argument(
  '--readline-prompt', type=str, default=">>> ",
  help="Input prompt (default: >>>)")
ARG_PARSER.add_argument(
  '--readline-history', type=str, default=HISTORY_DEF, metavar='FILE',
  help=f"History file name (default is '{HISTORY_DEF}'; set empty to disable)")

# --- Miscellaneous ---
# Kept as a string here; it is converted with `as_int` where it is used.
ARG_PARSER.add_argument(
  '--verbose', type=str, metavar='NUM',
  help="Set the verbosity level 0-no,1-full")
ARG_PARSER.add_argument(
  '--revision', action='store_true',
  help="Print the revision")



@contextmanager
def with_sigint(_handler):
  """ Temporarily install `_handler` as the SIGINT handler.

  The handler that was active on entry is put back on exit, including when the
  body raises. This is not fully correct: signals should also be masked while
  the handlers are being switched. """
  saved_handler = signal(SIGINT, _handler)
  try:
    yield
  finally:
    signal(SIGINT, saved_handler)

def as_float(val:str, default:float|None)->float|None:
  return float(val) if val not in {None,"def","default"} else default
def as_int(val:str, default:int|None)->int|None:
  return int(val) if val not in {None,"def","default"} else default

def ask1(args, model:ModelProvider, message:str) -> str|None:
  """ Send `message` to `model`, echo the streamed answer and return it.

  Every token produced by `model.stream` is printed as soon as it arrives and
  is also collected. While streaming, SIGINT is redirected to a handler that
  calls `model.interrupt()`. A final newline is always printed, also when an
  exception propagates.

  Parameters:
  - `args`: parsed command-line namespace; `verbose` and `model_temperature`
    are read from it.
  - `model`: the provider to query.
  - `message`: the text to send.

  Returns the text collected from the stream (the previous version returned
  the already-closed buffer object instead of its contents). """
  opt:Options = Options(verbose=as_int(args.verbose, 0))

  def _signal_handler(signum,frame):
    model.interrupt()
    print("\n<Keyboard interrupt>")

  with io.StringIO() as response:
    try:
      model.set_temperature(as_float(args.model_temperature,None))
      response_generator = model.stream(message, opt=opt)

      with with_sigint(_signal_handler):
        for token in response_generator:
          print(token, end='', flush=True)
          response.write(token)

      # Read the text out before the buffer is closed: closing a StringIO
      # discards its contents.
      return response.getvalue()
    finally:
      print()

def print_help():
  """ Print all entries of COMMANDS on one line, separated by spaces. """
  listing = ' '.join(COMMANDS)
  print(f"Commands: {listing}")

def print_aux(args, s:str)->None:
  """ Print the informational message `s` with an "INFO: " prefix.

  `args` is accepted but not used. """
  line = f"INFO: {s}"
  print(line)

def print_aux_err(args, s:str)->None:
  """ Print the error message `s` with an "ERROR: " prefix.

  `args` is accepted but not used. """
  line = f"ERROR: {s}"
  print(line)

@contextmanager
def with_chat_session(model:ModelProvider):
  """ Run the body inside `model.with_chat_session()`.

  When `model` is None the body simply runs without any session. """
  if model is not None:
    with model.with_chat_session():
      yield
    return
  yield

def model_locations(args, model)->list[str]:
  """ Yield the resolved paths of files matching the glob pattern `model`.

  The pattern is tried as given and then, if `args.model_dir` is set, joined
  onto that directory. `~` is expanded before globbing and each match is passed
  through `realpath`. Note that, despite the annotation, this is a generator. """
  patterns = [model]
  if args.model_dir:
    patterns.append(join(args.model_dir,model))
  for pattern in patterns:
    yield from map(realpath, glob(expanduser(pattern)))

def parse_modelname(args) -> tuple[callable,str]:
  """ Split `args.model` of the form `[PROVIDER:]NAME` into constructor and name.

  Returns a pair of (partially applied provider constructor, model name). A
  missing provider means 'gpt4all'. The gpt4all constructor is bound to
  `args.device`, the openai constructor to `args.model_apikey`.

  Raises ValueError for an unknown provider or for more than one ':'. """
  parts = args.model.split(':')
  if len(parts) > 2:
    raise ValueError(f"Invalid model name '{parts}'")

  # A bare name selects the default provider.
  provider, name = parts if len(parts) == 2 else ('gpt4all', parts[0])

  if provider == 'gpt4all':
    return partial(GPT4AllModelProvider, device=args.device), name
  if provider == 'openai':
    return partial(OpenAIModelProvider, api_key=args.model_apikey), name
  raise ValueError(f"Invalid model provider '{provider}'")

def main(cmdline=None):
  """ Run the interactive read-eval-print loop of the program.

  `cmdline` is the list of command-line arguments; None makes argparse read
  them from `sys.argv`. With `--revision` the revision is printed and 0 is
  returned. Otherwise lines are read with readline, parsed with PARSER and
  interpreted until the exit command or end-of-file; None is returned then.

  The loop is two-level: the outer loop (re)applies the settings held in `args`
  and opens a chat session, the inner loop reads input until a command asks for
  an exit or a reset. ValueError and RuntimeError raised while handling a line
  are reported and restart the outer loop. """
  args = ARG_PARSER.parse_args(cmdline)
  if args.revision:
    print(REVISION)
    return 0
  parse_and_bind('tab: complete')
  # Make the "send" key type the ask command followed by a newline.
  parse_and_bind(f'"{args.readline_key_send}": "{CMD_ASK}\n"')
  hint = args.readline_key_send.replace('\\', '')

  print(f"Type /help or a question followed by the /ask command (or by pressing "
        f"`{hint}` key).")
  old_model = None  # value of `args.model` the current `model` was built from
  model = None      # active model provider, or None

  def _apply():
    """ Bring `model` in line with the settings currently stored in `args`. """
    nonlocal model, old_model
    if old_model != args.model:
      if args.model is not None:
        ctr, modelname = parse_modelname(args)
        model = ctr(modelname)
        print_aux(args, "Model is now initialized")
      else:
        if model:
          print_aux(args, "Model is now destroyed")
          del model
        model = None
      old_model = args.model
    if model is not None:
      num_threads = as_int(args.num_threads, None)
      old_num_threads = model.get_thread_count()
      if num_threads is not None:
        model.set_thread_count(num_threads)
        new_num_threads = model.get_thread_count()
        if old_num_threads != new_num_threads:
          print_aux(args, f"Num threads is now {new_num_threads}")

  class Repl(Interpreter):
    """ Interpreter of parsed input lines.

    Holds the message buffer and the exit/reset request flags. The `command`,
    `text` and `escape` methods are presumably named after rules of the PARSER
    grammar (not visible here) and are invoked through `visit`. """
    def __init__(self):
      self._reset()
    def _reset(self):
      self.in_echo = False
      self.message = ""
      self.exit_request = False
      self.reset_request = False
    def reset(self):
      """ Clear the state, reporting when a non-empty message is dropped. """
      old_message = self.message
      self._reset()
      if len(old_message)>0:
        print_aux(args, "Message buffer is now empty")
    def _finish_echo(self):
      """ Terminate the line started by an echo command, if there is one. """
      if self.in_echo:
        print()
      self.in_echo = False
    def command(self, tree):
      """ Execute one command node.

      Raises RuntimeError when asking is impossible (no model, empty message)
      and ValueError for an unknown command. """
      # `model` is only read here. It must never be rebound from this method:
      # the enclosing loop owns the provider object and `_apply` manages it.
      self._finish_echo()
      command = tree.children[0].value
      if command == CMD_ECHO:
        self.in_echo = True
      elif command == CMD_ASK:
        if model is None:
          raise RuntimeError("No model is active, use /model first")
        if len(self.message.strip()) == 0:
          raise RuntimeError("Empty message buffer, write something first")
        else:
          ask1(args, model, self.message)
          self.message = ""
      elif command == CMD_HELP:
        print_help()
      elif command == CMD_EXIT:
        self.exit_request = True
      elif command == CMD_MODEL:
        # The first and last characters of the token are dropped (presumably
        # the quotes required by the grammar).
        argument = tree.children[2].children[0].value[1:-1]
        # Prefer the first existing file matching the argument. A separate
        # local name is used so that the active provider is not overwritten
        # with a path string.
        location = next(
          (path for path in model_locations(args, argument) if isfile(path)),
          None)
        if location is not None:
          args.model = location
        else:
          args.model = argument if len(argument)>0 else None
        print_aux(args, f"Model will be set to '{args.model}'")
        self.reset_request = True
      elif command == CMD_NTHREADS:
        args.num_threads = tree.children[2].children[0].value
        print_aux(args, f"Num threads will be set to '{args.num_threads}'")
        self.reset_request = True
      elif command == CMD_TEMP:
        args.model_temperature = tree.children[2].children[0].value
        print_aux(args, f"Temperature will be set to '{args.model_temperature}'")
      elif command == CMD_RESET:
        print_aux(args, "Message buffer will be cleared")
        self.reset_request = True
      elif command == CMD_APIKEY:
        argument = tree.children[2].children[0].value[1:-1]
        print_aux(args, f"Model API key will be set to \"{argument}\"")
        args.model_apikey = argument
        self.reset_request = True
      elif command == CMD_VERBOSE:
        args.verbose = tree.children[2].children[0].value
        print_aux(args, f"Verbosity will be set to '{args.verbose}'")
      else:
        raise ValueError(f"Unknown command: {command}")
    def text(self, tree):
      """ Echo a text node or append it to the message buffer. """
      text = tree.children[0].value
      if self.in_echo:
        print(text, end='')
      else:
        # A command name inside plain text usually means a mistyped command.
        for cmd in COMMANDS:
          if cmd in text:
            print_aux(args, f"Warning: '{cmd}' was parsed as a text")
        self.message += text
    def escape(self, tree):
      """ Handle an escape node: its first character is dropped. """
      text = tree.children[0].value[1:]
      if self.in_echo:
        print(text, end='')
      else:
        self.message += text

  if args.readline_history:
    try:
      read_history_file(args.readline_history)
      print_aux(args, "History file loaded")
    except FileNotFoundError:
      print_aux(args, "History file not loaded")
  else:
    print_aux(args, "History file is not used")

  repl = Repl()
  while not repl.exit_request:
    _apply()
    repl.reset()
    with with_chat_session(model):
      try:
        while not (repl.exit_request or repl.reset_request):
          repl.visit(PARSER.parse(input(args.readline_prompt)))
          repl._finish_echo()
          if args.readline_history:
            write_history_file(args.readline_history)
      except (ValueError,RuntimeError) as err:
        print_aux_err(args, err)
      except EOFError:
        # End of input (e.g. Ctrl+D): finish the prompt line and leave.
        print()
        break

# Start the interactive loop when the file is executed as a script.
if __name__ == "__main__":
  main()
