#!/usr/bin/env python3

"""
Command line utility to run SQL using lava database connectors.

Relies on the lava connector subsystem.
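
Example (the command name, connection ID and realm below are illustrative):

    lava-sql -c my-conn-id -r my-realm --header --format csv query.sql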

"""

from __future__ import annotations

import csv
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from collections.abc import Iterator

import boto3
import jinja2
import sqlparse
from smart_open import open  # noqa: A004

from lava.connection import get_pysql_connection
from lava.lavacore import LOGNAME
from lava.lib.db import begin_transaction
from lava.lib.logging import setup_logging
from lava.lib.misc import json_default
from lava.version import __version__

try:
    import pyarrow as pa
    import pyarrow.parquet as pq
except ImportError:
    pa = None
    pq = None

__author__ = 'Murray Andrews'

PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]
LOG = logging.getLogger(LOGNAME)
LOGLEVEL = 'info'
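# Default number of rows fetched per batch; override with -b/--batch-size.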
BATCH_SIZE = 1024

# ..............................................................................
# region data writers
# ..............................................................................


# ------------------------------------------------------------------------------
class DataWriter(ABC):
    """
    Abstract base class for a data writer.

    This is also a context manager.

    .. warning::
        There is an obvious logical flaw here: nothing prevents the results
        of multiple queries being written to the same stream, which ends in
        a mess. So don't do that.

    :param filename:    Write output to the specified file. If None, use stdout.
    :param format_args: A namespace containing format specific args. These come
                        from CLI args.
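
    Example (a sketch; assumes a cursor with an executed query)::

        with CsvWriter('out.csv', format_args) as writer:
            rows = writer.write(cursor, header=True)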
    """

    FORMATTERS: dict[str, type[DataWriter]] = {}
    MODE = 'wt'  # File writing mode for output

    # --------------------------------------------------------------------------
    @classmethod
    def formatter(cls, fmt: str):
        """
        Register format handler classes.

        Usage::

            @DataWriter.formatter('some-format')
            class AWriter(DataWriter):
                ...

        :param fmt:     Format label. If None, the handler is not registered.
        """

        def decorate(handler) -> type[DataWriter]:
            """
            Register the handler cls.

            :param handler: Class to register.
            :return:        Unmodified class.

            """

            if fmt in cls.FORMATTERS:
                raise ValueError(f'{fmt} is already registered')
            if fmt:
                cls.FORMATTERS[fmt] = handler
            return handler

        return decorate

    # --------------------------------------------------------------------------
    def __init__(self, filename: str | None, format_args: Namespace, s3client=None):
        """Init."""

        self.out_fp = None
        self.filename = filename
        self.format_args = format_args
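        # Fall back to a default S3 client; main() supplies one built from the
        # selected AWS profile.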
        self.s3client = s3client or boto3.Session().client('s3')

    # --------------------------------------------------------------------------
    def __enter__(self) -> DataWriter:
        """Context manager open."""

        # smart_open's open() handles local files and s3:// URLs alike.
        self.out_fp = (
            open(self.filename, self.MODE, transport_params={'client': self.s3client})
            if self.filename
            else sys.stdout
        )
        return self

    # --------------------------------------------------------------------------
    def __exit__(self, *args, **kwargs):
        """Context manager exit."""

        if self.filename and self.out_fp:
            self.out_fp.close()

    # --------------------------------------------------------------------------
    @classmethod
    def cli_args(cls, argp: ArgumentParser):  # noqa: B027
        """Add format specific CLI arguments."""

        pass

    # --------------------------------------------------------------------------
    @abstractmethod
    def write(self, cursor, batch_size: int = BATCH_SIZE, header: bool = False) -> int:
        """
        Write out query results.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.
        :param header:      If True, produce a header row for SELECT queries.
                            Default is False.

        :return:            Number of data rows written.
        """

        raise NotImplementedError('write')


# ------------------------------------------------------------------------------
@DataWriter.formatter('csv')
class CsvWriter(DataWriter):
    """CSV writer."""

    DEFAULT_DIALECT = 'excel'

    # ------------------------------------------------------------------------------
    def __init__(self, *args, **kwargs):
        """Initialise CSV writer."""

        super().__init__(*args, **kwargs)

        try:
            self.format_args.quoting = getattr(csv, 'QUOTE_' + self.format_args.quoting.upper())
        except AttributeError:
            raise ValueError(f'Bad quoting style: {self.format_args.quoting}') from None

        self.writer = None

    # --------------------------------------------------------------------------
    def __enter__(self) -> DataWriter:
        """Context manager open."""

        super().__enter__()  # Opens self.out_fp: file, S3 object or stdout.
        self.writer = csv.writer(
            self.out_fp,
            dialect=self.format_args.dialect,
            delimiter=self.format_args.delimiter,
            doublequote=self.format_args.doublequote,
            escapechar=self.format_args.escapechar,
            quotechar=self.format_args.quotechar,
            quoting=self.format_args.quoting,
        )

        return self

    # ------------------------------------------------------------------------------
    @classmethod
    def cli_args(cls, argp: ArgumentParser):
        """Add format specific CLI arguments."""

        grp = argp.add_argument_group('CSV format arguments')

        grp.add_argument(
            '--delimiter',
            metavar='CHAR',
            action='store',
            default='|',
            help='Single character field delimiter. Default |.',
        )

        grp.add_argument(
            '--dialect',
            action='store',
            default=cls.DEFAULT_DIALECT,
            choices=csv.list_dialects(),
            help=f'CSV dialect (as per the Python csv module). Default is {cls.DEFAULT_DIALECT}.',
        )

        grp.add_argument(
            '--doublequote',
            action='store_true',
            help='See Python csv.writer.',
        )

        grp.add_argument(
            '--escapechar',
            metavar='CHAR',
            action='store',
            help='See Python csv.writer. Escaping is disabled by default.',
        )

        grp.add_argument(
            '--quotechar',
            metavar='CHAR',
            action='store',
            default='"',
            help='See Python csv.writer. Default is ".',
        )

        grp.add_argument(
            '--quoting',
            action='store',
            default='minimal',
            help=(
                'As for csv.writer QUOTE_* parameters (without the QUOTE_ prefix).'
                ' Default is minimal (i.e. QUOTE_MINIMAL).'
            ),
        )

    # ------------------------------------------------------------------------------
    def write(self, cursor, batch_size: int = BATCH_SIZE, header: bool = False) -> int:
        """
        Write data from the cursor in CSV format.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.
        :param header:      If True, produce a header row for SELECT queries.
                            Default is False.

        :return:            Number of data rows written.
        """

        if not cursor.description:
            return 0

        if header:
            self.writer.writerow(col[0] for col in cursor.description)

        rows = 0
        while True:
            if not (batch := cursor.fetchmany(batch_size)):
                break

            self.writer.writerows(batch)
            rows += len(batch)

        return rows


# ------------------------------------------------------------------------------
@DataWriter.formatter('jsonl')
class JsonlWriter(DataWriter):
    """Writer producing one JSON object per line."""

    # ------------------------------------------------------------------------------
    @classmethod
    def cli_args(cls, argp: ArgumentParser):
        """Add format specific CLI arguments."""

        grp = argp.add_argument_group('JSONL format arguments')
        grp.add_argument('--sort-keys', action='store_true', help='Sort keys in JSON objects.')

    # ------------------------------------------------------------------------------
    def write(self, cursor, batch_size: int = BATCH_SIZE, header: bool = False) -> int:
        """
        Write data from the cursor in JSONL format.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.
        :param header:      Ignored for this format.

        :return:            Number of data rows written.
        """

        if not cursor.description:
            return 0

        column_names = tuple(c[0] for c in cursor.description)

        rows = 0
        while True:
            if not (batch := cursor.fetchmany(batch_size)):
                break

            for row in batch:
                json.dump(
                    dict(zip(column_names, row)),
                    self.out_fp,
                    sort_keys=self.format_args.sort_keys,
                    default=json_default,
                )
                print(file=self.out_fp)

            rows += len(batch)

        return rows


# ------------------------------------------------------------------------------
@DataWriter.formatter('html')
class HtmlWriter(DataWriter):
    """
    HTML table writer.

    Only the table HTML is output so this can be embedded into a larger HTML
    document.

    The process uses Jinja to take advantage of its HTML escaping.

    """

    HTML = """
        {%- autoescape true -%}
        <TABLE class="lava-sql">
            {% if header %}<THEAD>
                <TR>
                    {% for h in header -%}
                    <TH>{{ h }}</TH>
                    {% endfor %}
                </TR>
            </THEAD>{% endif %}
            <TBODY>
            {%- for row in data %}
                <TR>
                    {% for val in row -%}
                        <TD>{{ val }}</TD>
                    {% endfor %}
                </TR>
            {% endfor %}
            </TBODY>
        </TABLE>
        {%- endautoescape %}
    """

    # ------------------------------------------------------------------------------
    def _data_stream(self, cursor, batch_size) -> Iterator[list]:
        """
        Stream rows of data from a cursor.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.

        :return:            A row generator.
        """

        self._rows = 0

        while True:
            batch = cursor.fetchmany(batch_size)
            if not batch:
                break
            self._rows += len(batch)
            yield from batch

    # ------------------------------------------------------------------------------
    def write(self, cursor, batch_size: int = BATCH_SIZE, header: bool = False) -> int:
        """
        Write data from the cursor as an HTML table.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.
        :param header:      If True, produce a header row for SELECT queries.
                            Default is False.

        :return:            Number of data rows written.
        """

        if not cursor.description:
            return 0

        # Stream the template so rows are written out incrementally rather
        # than being materialised in memory first.
        self._rows = 0  # Updated by _data_stream() as the template consumes rows.
        jinja2.Template(self.HTML).stream(
            header=[col[0] for col in cursor.description] if header else None,
            data=self._data_stream(cursor, batch_size),
        ).dump(self.out_fp)
        return self._rows


# ------------------------------------------------------------------------------
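# The conditional label means the parquet formatter is only registered (and so
# only offered for --format) when pyarrow is installed.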
@DataWriter.formatter('parquet' if pa else None)
class ParquetWriter(DataWriter):
    """Parquet writer."""

    MODE = 'wb'  # Need binary for parquet

    # ------------------------------------------------------------------------------
    def write(self, cursor, batch_size: int = BATCH_SIZE, header: bool = False) -> int:
        """
        Write data from the cursor in Parquet format.

        This is a bit of a half-hearted implementation. It's quite difficult to
        handle schema inference in a predictable or consistent way, particularly
        with data sourced via a DBAPI 2 connector as the "standard" does not
        provide any consistency in how implementations signal type information
        in cursor.description. So we let pyarrow examine the first batch of
        records and have its own guess.

        :param cursor:      A database cursor on which a query has been executed.
        :param batch_size:  Number of rows to fetch at a time.
        :param header:      Ignored for this format.

        :return:            Number of data rows written.
        """

        if not cursor.description:
            return 0
        column_names = tuple(c[0] for c in cursor.description)

        # Let pyarrow infer the schema from the first batch.
        batch = cursor.fetchmany(batch_size)
        if not batch:
            return 0
        # noinspection PyArgumentList
        pq_batch = pa.RecordBatch.from_pylist([dict(zip(column_names, row)) for row in batch])
        schema = pq_batch.schema

        rows = 0
        with pq.ParquetWriter(self.out_fp, schema) as writer:
            while True:
                writer.write_batch(pq_batch)
                rows += len(batch)
                batch = cursor.fetchmany(batch_size)
                if not batch:
                    break
                # Reuse the inferred schema so that a later batch (e.g. one with
                # an all-null column) cannot infer a conflicting one.
                # noinspection PyArgumentList
                pq_batch = pa.RecordBatch.from_pylist(
                    [dict(zip(column_names, row)) for row in batch], schema=schema
                )
        return rows


# ..............................................................................
# endregion data writers
# ..............................................................................


# ------------------------------------------------------------------------------
def process_cli_args() -> Namespace:
    """
    Process the command line arguments.

    :return:    The args namespace.
    """

    argp = ArgumentParser(prog=PROG, description='Run SQL using lava database connections.')

    argp.add_argument('--profile', action='store', help='AWS profile name, as for the AWS CLI.')

    argp.add_argument('-v', '--version', action='version', version=__version__)

    argp.add_argument(
        '-a',
        '--app-name',
        '--application-name',
        action='store',
        metavar='NAME',
        help=(
            'Use the specified application name when connecting to the database.'
            ' Ignored for database types that don\'t support this concept.'
        ),
    )
    argp.add_argument(
        '-b',
        '--batch-size',
        type=int,
        action='store',
        default=BATCH_SIZE,
        help=(
            f'Number of records per batch when processing SELECT queries. Default is {BATCH_SIZE}.'
        ),
    )

    argp.add_argument(
        '--format',
        action='store',
        choices=DataWriter.FORMATTERS,
        default='csv',
        help='Output format. Default is csv.',
    )

    argp.add_argument(
        '--header',
        action='store_true',
        help='Print a header for SELECT queries (output format dependent).',
    )

    argp.add_argument(
        '-o',
        '--output',
        metavar='FILENAME',
        action='store',
        help=(
            'Write output to the specified file which may be local or in S3'
            ' (s3://...). If not specified, output is written to stdout.'
        ),
    )

    argp.add_argument(
        '--raw',
        action='store_true',
        help=(
            'Don\'t split SQL source files into individual statements.'
            ' By default, an attempt will be made to split each source file into individual'
            ' SQL statements.'
        ),
    )

    argp.add_argument(
        '--transaction',
        action='store_true',
        help='Disable auto-commit and run all SQLs in a transaction.',
    )

    argp.add_argument(
        'sql_files',
        metavar='SQL-FILE',
        nargs='*',
        default=['-'],
        help=(
            'SQL files. These can be local or in S3 (s3://...).'
            ' If not specified or "-", stdin is used.'
        ),
    )

    # ------------------------------
    lavap = argp.add_argument_group('lava arguments')

    lavap.add_argument(
        '-c',
        '--conn-id',
        dest='conn_id',
        required=True,
        action='store',
        help='Lava database connection ID. Required.',
    )

    lavap.add_argument(
        '-r',
        '--realm',
        action='store',
        help=(
            'Lava realm name. If not specified, the environment variable LAVA_REALM must be set.'
        ),
    )

    # ------------------------------
    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help='Don\'t use colour in information messages.',
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Print messages of a given severity level or above. The standard logging level names'
            ' are available but debug, info, warning and error are most useful.'
            f' Default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help='Log to the specified target. This can be either a file'
        ' name or a syslog facility with an @ prefix (e.g. @local0).',
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    # ------------------------------
    # Add format specific options
    for fmt in sorted(DataWriter.FORMATTERS):
        DataWriter.FORMATTERS[fmt].cli_args(argp)

    # ------------------------------
    args = argp.parse_args()

    if args.batch_size < 1:
        argp.error('Batch size must be >= 1.')

    if not args.realm:
        try:
            args.realm = os.environ['LAVA_REALM']
        except KeyError:
            argp.error(
                'Lava realm must be specified via -r, --realm or LAVA_REALM environment variable.'
            )

    return args


# ------------------------------------------------------------------------------
def do_sql_file(cursor, sql_file: str, args: Namespace, writer: DataWriter, s3) -> tuple[int, int]:
    """
    Run the SQL statements in the specified file.

    :param cursor:      A database cursor.
    :param sql_file:    SQL source file. An empty value and '-' mean stdin.
    :param args:        A namespace containing all the CLI args which will
                        include the format specific ones.
    :param writer:      A data writer.
    :param s3:          A boto3 S3 client.

    :return:            A tuple of ints (queries run, rows returned).
    """

    if args.batch_size < 1:
        raise ValueError(f'Bad batch size: {args.batch_size}')

    if sql_file and sql_file != '-':
        with open(sql_file, transport_params={'client': s3}) as sql_fp:
            raw_sql = sql_fp.read()
    else:
        raw_sql = sys.stdin.read()

    LOG.debug('Raw SQL: %s: %s', sql_file, raw_sql)

    # Split up the payload into statements unless raw
    sql_statements = [raw_sql] if args.raw else sqlparse.split(raw_sql.strip())

    n = 0
    rows = 0
    for n, sql in enumerate(sql_statements, start=1):
        LOG.debug('Preparing SQL %d: %s', n, sql)

        # Strip any trailing semicolon; some DBAPI drivers reject them.
        sql = sql.strip().rstrip(';')
        if not sql:
            LOG.debug('Skipping empty SQL statement # %d', n)
            continue

        cursor.execute(sql)
        LOG.debug('Cursor rowcount=%d, description=%s', cursor.rowcount, cursor.description)
        if not cursor.description:
            continue
        rows += writer.write(cursor, batch_size=args.batch_size, header=args.header)

    return n, rows


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Do the business.

    :return:    status
    """

    # Bootstrap logging so that argument processing problems are reported, then
    # reconfigure below once the CLI logging options are known.
    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()
    setup_logging(args.level, name=LOGNAME, target=args.log, colour=args.colour, prefix=args.tag)

    aws_session = boto3.Session(profile_name=args.profile)
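    # A single S3 client is shared by the SQL file reader and the data writer.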
    s3 = aws_session.client('s3')
    conn = get_pysql_connection(
        args.conn_id,
        args.realm,
        autocommit=not args.transaction,
        aws_session=aws_session,
        application_name=args.app_name,
    )
    try:
        LOG.debug('Autocommit set to %s', conn.autocommit)
    except AttributeError:
        LOG.debug('Autocommit not supported for connector %s', args.conn_id)

    cursor = conn.cursor()

    sql_file = None
    writer_class = DataWriter.FORMATTERS[args.format]
    with writer_class(args.output, args, s3client=s3) as writer:
        if args.transaction:
            LOG.info('Starting transaction')
            begin_transaction(conn, cursor)

        try:
            for sql_file in args.sql_files:
                queries, rows = do_sql_file(cursor, sql_file, args, writer, s3)
                qq = 'query' if queries == 1 else 'queries'
                rr = 'row' if rows == 1 else 'rows'
                LOG.info('%s: %d %s completed, %d %s returned', sql_file, queries, qq, rows, rr)
        except Exception as e:
            if args.transaction:
                conn.rollback()
                if sql_file:
                    raise Exception(f'{sql_file}: Rollback: {e}') from e
            raise
        else:
            if args.transaction:
                LOG.info('Committing transaction')
                conn.commit()
        finally:
            conn.close()

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment the next line (bypassing the exception handler below) to get
    # full tracebacks when debugging.
    # exit(main())  # noqa: ERA001
    try:
        exit(main())
    except Exception as ex:
        LOG.error(str(ex))
        exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        exit(2)
