#!/usr/bin/env python3

"""Validate ingestion-sink support for sinking to *_live tables."""

from argparse import ArgumentParser
from datetime import datetime
from multiprocessing.pool import ThreadPool
from os.path import abspath, dirname
from textwrap import dedent
import logging
import sys
import warnings

from google.cloud import bigquery

# sys.path needs to be modified to enable package imports from parent
# and sibling directories. Also see:
# https://stackoverflow.com/questions/6323860/sibling-package-imports/23542795#23542795
sys.path.append(dirname(dirname(dirname(dirname(abspath(__file__))))))
from bigquery_etl.util import standard_args  # noqa E402


parser = ArgumentParser()
standard_args.add_dry_run(parser)
standard_args.add_log_level(parser)
standard_args.add_parallelism(parser)
standard_args.add_table_filter(parser)
parser.add_argument(
    "-d",
    "--date",
    required=True,
    type=lambda d: datetime.strptime(d, "%Y-%m-%d").date(),
    help="The day for which to compare data, in format 2019-01-01",
)


def comparable_schema(fields):
    return {
        field.name: {
            **{
                k: v
                for k, v in field.to_api_repr().items()
                # ignore description
                # skip fields because it is handled below
                # skip name because it is the dict key
                if k not in ("description", "fields", "name")
            },
            **({"fields": comparable_schema(field.fields)} if field.fields else {}),
        }
        for field in sorted(fields, key=lambda f: f.name)
    }


def comparable_schema_to_list_schema(schema):
    return [
        {
            "name": key,
            **value,
            **(
                {"fields": comparable_schema_to_list_schema(value["fields"])}
                if "fields" in value
                else {}
            ),
        }
        for key, value in schema.items()
    ]


def deep_del_matching_keys(*items):
    """Remove matching keys from a series of dicts."""
    for key in sorted({k for item in items for k in item}):
        if all(key in item for item in items):
            values = (item[key] for item in items)
            value = next(values)
            if all(value == v for v in values):
                for item in items:
                    del item[key]
                continue
            if all(isinstance(item[key], dict) for item in items):
                deep_del_matching_keys(*(item[key] for item in items))
        # remove empty items
        for item in items:
            if key in item and not item[key]:
                del item[key]


def lines(string):
    return [line + "\n" for line in string.split("\n")]


def human_readable(
    value, suffixes=[" bytes", "KiB", "MiB", "GiB", "TiB", "PiB"], step=1024
):
    if value < step and isinstance(value, int):
        return f"{value}{suffixes[0]}"
    for i, suffix in enumerate(suffixes):
        if value < step ** (i + 1):
            break
    return f"{value/step**i:,.2f}{suffix}"


def log_difference(table, prod_value, stage_value, case="in stage", **kwargs):
    logging.warning(
        f"{table}: expected {human_readable(abs(prod_value-stage_value), **kwargs)} "
        f"{'more' if prod_value > stage_value else 'less'} {case}"
    )


def compare(
    client, prod_sql_table_id, stage_sql_table_id, submission_date, pool, dry_run
):
    different = False
    table = stage_sql_table_id.split(".", 1)[-1]
    prod, stage = [
        client.get_table(f"{table}${submission_date:%Y%m%d}")
        for table in (prod_sql_table_id, stage_sql_table_id)
    ]
    if prod.num_rows != stage.num_rows:
        different = True
        log_difference(
            table,
            prod.num_rows,
            stage.num_rows,
            suffixes=[f"{s} rows" for s in ["", "K", "M", "G", "T", "P"]],
            step=1000,
        )
    elif prod.num_bytes != stage.num_bytes:
        different = True
        log_difference(table, prod.num_bytes, stage.num_bytes)
        diff = prod.num_bytes - stage.num_bytes
        prod_partition_bytes = stage_partition_bytes = 0
        partition_field = "submission_timestamp"
        prod_fields = set(f.name for f in prod.schema)
        stage_fields = set(f.name for f in stage.schema)
        # put partition_field first,
        # then environment,
        # then additional_properties,
        # then remaining fields in alphabetical order
        # this is to target known sources of differences first
        overlap_fields = sorted(
            prod_fields.intersection(stage_fields),
            key=lambda f: (
                f != partition_field,
                f != "environment",
                f != "additional_properties",
                f,
            ),
        )
        for field_name in overlap_fields:
            prod_bytes, stage_bytes = [
                client.query(
                    f"SELECT {field_name} FROM `{table}` "
                    f"WHERE DATE({partition_field}) = '{submission_date}'",
                    bigquery.QueryJobConfig(dry_run=True),
                ).total_bytes_processed
                for table in (prod_sql_table_id, stage_sql_table_id)
            ]
            if field_name == partition_field:
                prod_partition_bytes, stage_partition_bytes = prod_bytes, stage_bytes
            else:
                # don't double-count partition bytes
                prod_bytes -= prod_partition_bytes
                stage_bytes -= stage_partition_bytes
            if prod_bytes != stage_bytes:
                log_difference(
                    table, prod_bytes, stage_bytes, case=f"for {field_name} in stage"
                )
                diff -= prod_bytes - stage_bytes
                if diff == 0:
                    break
        else:
            logging.error(
                f"{table}: expected {human_readable(abs(diff))} "
                "less from fields not found in prod and stage"
            )
    prod_schema, stage_schema = [comparable_schema(t.schema) for t in (prod, stage)]
    if prod_schema != stage_schema:
        deep_del_matching_keys(prod_schema, stage_schema)
        prod_schema, stage_schema = [
            comparable_schema_to_list_schema(s) for s in (prod_schema, stage_schema)
        ]
        different = True
        if prod_schema:
            logging.warning(f"{table}: expected stage schema to include {prod_schema} ")
        if stage_schema:
            logging.warning(
                f"{table}: expected stage schema NOT to include {stage_schema}"
            )
    # if the table sizes are different, they probably contain �s that need to
    # be converted to ?s for before they are compared
    convert_replacement_character = different
    # crash_v4 contains fox emoji that need to be need to be converted to ??s
    convert_fox_emoji = prod_sql_table_id.endswith(".crash_v4")
    # main_v4 is too large to compare all at once, so run per sample_id queries
    by_sample_id = prod_sql_table_id.endswith(".main_v4")
    query = dedent(
        '''
        CREATE TEMP FUNCTION sort_json(json_string STRING)
        RETURNS STRING
        LANGUAGE js
        AS """
        var sortJson = function(obj) {
          if (Array.isArray(obj)) {
            let result = [];
            for (var value of obj) {
              if (value !== null && typeof value == 'object') {
                result.push(sortJson(value));
              } else {
                result.push(value);
              }
            }
            return result;
          } else {
            let result = {};
            for (var key of Object.keys(obj).sort()) {
              let value = obj[key];
              if (value !== null && typeof value == 'object') {
                result[key] = sortJson(value);
              } else {
                result[key] = value;
              }
            }
            return result;
          }
        };
        return JSON.stringify(sortJson(JSON.parse(json_string)));
        """;
        CREATE TEMP FUNCTION remove_null_and_empty_values(json_string STRING)
        RETURNS STRING
        LANGUAGE js
        AS """
        var removeNullAndEmptyValues = function(obj) {
          if (Array.isArray(obj)) {
            let result = [];
            for (var value of obj) {
              if (typeof value == 'object') {
                result.push(removeNullAndEmptyValues(value));
              } else {
                result.push(value);
              }
            }
            return result;
          } else {
            let result = {};
            for (var key of Object.keys(obj).sort()) {
              let value = obj[key];
              if (value === null) {
                // drop value
              } else if (typeof value == 'object') {
                let newValue = removeNullAndEmptyValues(value);
                if (Object.keys(newValue).length > 0) {
                  result[key] = newValue;
                }
              } else {
                result[key] = value;
              }
            }
            return result;
          }
        };
        return JSON.stringify(removeNullAndEmptyValues(JSON.parse(json_string)));
        """;
        '''
        f"""
        WITH stage AS (
          SELECT
            sort_json(additional_properties) AS additional_properties,
            document_id,
            remove_null_and_empty_values(
              {'' if convert_replacement_character else '-- '}REPLACE(
              {'' if convert_fox_emoji else '-- '}REPLACE(
                TO_JSON_STRING(
                  (SELECT AS STRUCT _.* EXCEPT (additional_properties, document_id))
                )
              {'' if convert_fox_emoji else '-- '},"🦊","??")
              {'' if convert_replacement_character else '-- '},"�","?")
            ) AS json_string,
          FROM
            {stage_sql_table_id} AS _
          WHERE
            DATE(submission_timestamp) = '{submission_date}'
            {'' if by_sample_id else '-- '}AND sample_id = @sample_id
        ),
        prod AS(
          SELECT
            sort_json(additional_properties) AS additional_properties,
            document_id,
            remove_null_and_empty_values(
              {'' if convert_replacement_character else '-- '}REPLACE(
              {'' if convert_fox_emoji else '-- '}REPLACE(
                TO_JSON_STRING(
                  (SELECT AS STRUCT _.* EXCEPT (additional_properties, document_id))
                )
              {'' if convert_fox_emoji else '-- '},"🦊","??")
              {'' if convert_replacement_character else '-- '},"�","?")
            ) AS json_string,
          FROM
            {prod_sql_table_id} AS _
          WHERE
            DATE(submission_timestamp) = '{submission_date}'
            {'' if by_sample_id else '-- '}AND sample_id = @sample_id
        ),
        joined AS (
          SELECT
            STRUCT(
              COALESCE(stage.json_string, 'null')
              != COALESCE(prod.json_string, 'null') AS mismatch,
              stage.json_string AS stage,
              prod.json_string AS prod
            ) AS json_string,
            STRUCT(
              COALESCE(stage.additional_properties, 'null')
              != COALESCE(prod.additional_properties, 'null') AS mismatch,
              stage.additional_properties AS stage,
              stage.additional_properties AS prod
            ) AS additional_properties,
          FROM
            stage
          FULL OUTER JOIN
            prod
          USING
            (document_id)
        )
        SELECT
          *
        FROM
          joined
        WHERE
          json_string.mismatch
          OR additional_properties.mismatch
        """
    ).rstrip()
    if dry_run:
        logging.debug(f"{table}: would compare individual rows via: {query}")
    else:
        if by_sample_id:
            differences = sum(
                pool.map(
                    lambda sample_id: len(
                        list(
                            client.query(
                                query,
                                bigquery.QueryJobConfig(
                                    query_parameters=[
                                        bigquery.ScalarQueryParameter(
                                            "sample_id", "INT64", sample_id
                                        )
                                    ]
                                ),
                            ).result()
                        )
                    ),
                    range(100),
                )
            )
        else:
            differences = len(list(client.query(query).result()))
        if differences > 0:
            logging.error(
                f"{table}: expected 0 differences excluding "
                + ("🦊, " if convert_fox_emoji else "")
                + ("� and " if convert_replacement_character else "")
                + f"null lists, but got {differences}"
            )
            return True
        else:
            if different:
                msg = "accounted for all differences by"
                log = logging.warning
            else:
                msg = "found no differences when"
                log = logging.info
            log(
                f"{table}: {msg} excluding "
                + ("🦊, " if convert_fox_emoji else "")
                + ("� and " if convert_replacement_character else "")
                + "null lists"
            )
            return False


def main():
    args = parser.parse_args()
    warnings.filterwarnings("ignore", module="google.auth._default")
    client = bigquery.Client()
    with ThreadPool(args.parallelism) as pool:
        results = pool.starmap(
            compare,
            [
                (
                    client,
                    f"{t.project}.{dataset_table}",
                    f"moz-fx-data-shar-nonprod-efed.{dataset_table}_k8s",
                    args.date,
                    pool,
                    args.dry_run,
                )
                for d in client.list_datasets("moz-fx-data-shared-prod")
                if d.dataset_id.endswith("_stable")
                for t in client.list_tables(d.reference)
                for dataset_table in [f"{t.dataset_id}.{t.table_id}"]
                if args.table_filter(dataset_table)
            ],
            chunksize=1,
        )
    if any(results):
        sys.exit(1)


if __name__ == "__main__":
    main()
