#!/usr/bin/env python

import argparse
import sys
import time
import warnings
from itertools import permutations

import numpy as np
import pandas as pd

from mindexer.utils.sampling import SampleEstimator
from mindexer.utils.mongodb import MongoCollection
from mindexer.utils.query import Query

# TODO: Find the root cause of the VisibleDeprecationWarning raised while the
# score table is processed in main() (presumably selecting columns by a list
# of tuple labels of different lengths) and remove this filter.
# ``np.warnings`` was only an accidental alias of the stdlib module and is gone
# in NumPy >= 1.24; the warning class itself lives in ``np.exceptions`` on
# newer NumPy and in the top-level namespace on older releases.
_VISIBLE_DEPRECATION_WARNING = getattr(
    np, "exceptions", np
).VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=_VISIBLE_DEPRECATION_WARNING)
# silence pandas' chained-assignment (SettingWithCopy) warnings
pd.options.mode.chained_assignment = None

# name of MongoDB's profiler collection that holds the recorded workload
SYSTEM_PROFILE = "system.profile"
# database in which the data sample is persisted while the tool runs
DEFAULT_SAMPLE_DB = "mindexer_samples"
# fraction of the collection that is sampled by default
DEFAULT_SAMPLE_RATIO = 0.001
# if the ratio would yield at most this many documents, a fixed-size sample
# of (up to) this many documents is taken instead
MIN_SAMPLE_SIZE = 1000
# upper bound on the number of fields in a candidate index
MAX_INDEX_FIELDS = 3
# relative cost per unit of work (a collection scan costs 1.0 per document)
IXSCAN_COST = 0.5  # query is covered by the index
FETCH_COST = 10  # documents have to be fetched after the index scan
SORT_COST = 10  # weight of the n*log2(n) sort work an index can avoid


def _load_workload(profile_collection, namespace):
    """Build the query workload from the profiler entries of ``namespace``.

    Args:
        profile_collection: the ``system.profile`` collection to scan.
        namespace: ``"<db>.<collection>"`` string the entries must match.

    Returns:
        A list of ``Query`` objects, one per profiled find command that could
        be parsed. Commands that cannot be parsed are reported and skipped.
    """
    # TODO include aggregate commands as well
    workload = []
    for entry in profile_collection.find({"ns": namespace, "op": "query"}):
        command = entry["command"]
        try:
            query = Query.from_mql(command["filter"])
            if "limit" in command:
                query.limit = command["limit"]
            if "sort" in command:
                query.sort = tuple(command["sort"].keys())
            if "projection" in command:
                # ignore projection keys with 0, specifically {_id: 0}
                query.projection = tuple(
                    k for k, v in command["projection"].items() if v == 1
                )
            workload.append(query)
        except Exception as e:
            # Deliberately broad: skipping a query we cannot handle is
            # preferable to aborting the whole run. ``.get`` keeps the warning
            # itself from failing when the command has no "filter" key.
            print(f"    Warning: skipping query {command.get('filter')}. {e}")
    return workload


def _generate_candidates(workload):
    """Return all candidate indexes for ``workload`` as a list of field tuples.

    Every permutation of up to ``MAX_INDEX_FIELDS`` fields of each query is a
    candidate. The list is sorted (shorter indexes first) so that candidate
    numbering and tie-breaking do not depend on set iteration order.
    """
    candidates = set()
    for query in workload:
        # only consider indexes with at most MAX_INDEX_FIELDS fields
        max_fields = min(len(query.fields), MAX_INDEX_FIELDS)
        for width in range(1, max_fields + 1):
            candidates.update(permutations(query.fields, width))
    # We can discard the singular index on "_id", as this is always present
    # in any collection
    candidates.discard(("_id",))
    return sorted(candidates, key=lambda fields: (len(fields), fields))


def _score_candidates(workload, candidates, estimator):
    """Score every candidate index against every query.

    Returns:
        A float array of shape ``(len(workload), len(candidates))`` holding the
        estimated benefit of each index for each query relative to a
        collection scan (0 means the index cannot be used).
    """
    estimate_cache = {}

    def get_estimate(query):
        # estimated cardinalities with model, memoized per query
        if query not in estimate_cache:
            estimate_cache[query] = estimator.estimate(query)
        return estimate_cache[query]

    # A float matrix from the start: benefits are fractional and must not be
    # squeezed into an integer container.
    scores = np.zeros((len(workload), len(candidates)), dtype=float)

    for nq, query in enumerate(workload):
        print(f"    query #{nq:<2}: {query}")
        for nc, candidate in enumerate(candidates):

            # score index for filtering
            fetch_query = query.index_intersect(candidate)
            if len(fetch_query.filter.keys()) == 0:
                # index can't be used, no benefit over collection scan
                benefit = 0
            else:
                # different costs per unit of work depending if the query is
                # covered or not
                cost = IXSCAN_COST if query.is_covered(candidate) else FETCH_COST

                # estimate for number of "work units"
                est = get_estimate(fetch_query)

                # if the query has a limit, and all fields of the filter are in
                # the index, then we can cap the upper bound of the estimate at
                # limit. if not, then the expected number of units of work to
                # find all matches is equal to est. The expected case is equal
                # to the worst case.
                # See https://math.stackexchange.com/questions/2595408/hypergeometric-distribution-expected-number-of-draws-until-k-successes-are-dra
                if query.limit is not None and query.is_subset(candidate):
                    est = min(query.limit, est)

                # benefit of the index over a collection scan (assuming cost of
                # collscan = 1.0 relative to other costs)
                benefit = estimator.get_cardinality() * 1.0 - est * cost

            # add additional benefit points if index can be used for sorting
            if query.can_use_sort(candidate):
                # cap at 1 to avoid log2(0), which is undefined
                est = max(1, get_estimate(query))
                benefit += est * np.log2(est) * SORT_COST

            scores[nq, nc] = benefit

    return scores


def _select_indexes_greedily(scores):
    """Greedily pick the columns of ``scores`` with the best marginal benefit.

    In every round the candidate with the highest summed score is chosen. For
    each query that a chosen index can serve (non-zero score), the remaining
    candidates are then only credited with their improvement over the best
    chosen index, never with less than 0.

    Returns:
        Column positions of the chosen candidates, in selection order.
    """
    num_queries, num_candidates = scores.shape
    available = np.ones(num_candidates, dtype=bool)
    # per query: is it served by a chosen index, and with which best score
    has_support = np.zeros(num_queries, dtype=bool)
    best_existing = np.full(num_queries, -np.inf)
    marginal = scores.copy()
    chosen = []

    for _ in range(num_candidates):
        # --- sum scores of all queries, ignoring already chosen candidates
        totals = marginal.sum(axis=0)
        totals[~available] = -np.inf
        pick = int(np.argmax(totals))

        # if nothing can be improved, we're done
        if totals[pick] <= 0:
            break

        available[pick] = False
        chosen.append(pick)

        # scores of 0 mean "index cannot support this query" and are ignored
        picked = scores[:, pick]
        usable = picked != 0
        best_existing[usable] = np.maximum(best_existing[usable], picked[usable])
        has_support |= usable

        # new score is the difference between the best index so far and the
        # candidate; since an index already exists for these queries we can't
        # make them worse, so negative values are set to 0. The result is
        # assigned explicitly so the clamp always takes effect.
        marginal[has_support] = np.clip(
            scores[has_support] - best_existing[has_support][:, np.newaxis],
            0,
            None,
        )

    return chosen


def main(
    uri,
    db,
    collection,
    sample_ratio=0.001,
    sample_db=DEFAULT_SAMPLE_DB,
    verbose=False,
    before_and_after=False,
):
    """Recommend indexes for ``db.collection`` based on its profiled workload.

    The find commands recorded in the database's ``system.profile`` collection
    form the workload. A data sample is persisted in ``sample_db`` and used to
    estimate cardinalities, every candidate index is scored against every
    query, and indexes are selected greedily until no candidate adds benefit.
    The sample collection is dropped again before the recommendation is
    printed, also when scoring fails.

    Args:
        uri: MongoDB connection string, passed to ``MongoCollection``.
        db: name of the database that holds the collection.
        collection: name of the collection to recommend indexes for.
        sample_ratio: fraction of the collection to sample; when it would
            yield at most ``MIN_SAMPLE_SIZE`` documents, a fixed-size sample
            of up to ``MIN_SAMPLE_SIZE`` documents is taken instead.
        sample_db: database in which the sample is persisted; must differ
            from ``db``.
        verbose: also print the workload, the candidates, timing and the
            score table.
        before_and_after: after an interactive confirmation, execute the
            workload, create the recommended indexes and execute it again.

    Raises:
        ValueError: if ``sample_db`` equals ``db``.
    """
    # Raised explicitly instead of asserted: the check must also hold under
    # ``python -O``, otherwise the sample collection would share its
    # namespace with the original collection and be dropped at the end.
    if sample_db == db:
        raise ValueError("Sample database cannot be the same as original database.")

    namespace = f"{db}.{collection}"
    mcollection = MongoCollection(uri, db, collection)
    database = mcollection.db

    # -- Workload
    print(f"\n>> scanning system.profile collection for queries on {namespace}\n")
    # find all queries in system.profile related to the collection
    workload = _load_workload(database[SYSTEM_PROFILE], namespace)

    print(f"\n>> found {len(workload)} queries for namespace {namespace}\n")
    if verbose:
        for i, query in enumerate(workload):
            print(f"{i:>4} {query}")

    # -- Sample Estimator
    total_docs = mcollection.count
    if total_docs * sample_ratio <= MIN_SAMPLE_SIZE:
        sampling = {"sample_size": min(total_docs, MIN_SAMPLE_SIZE)}
    else:
        sampling = {"sample_ratio": sample_ratio}
    estimator = SampleEstimator(
        mcollection,
        sample_db_name=sample_db,
        persist=True,
        **sampling,
    )

    print(f"\n>> extracted data sample, persisted at {sample_db}.{collection}\n")

    try:
        # -- generate list of index candidates
        candidates = _generate_candidates(workload)

        print(f"\n>> generated {len(candidates)} candidate indexes\n")
        if verbose:
            for ic, candidate in enumerate(candidates):
                print(f"    {ic}   {candidate}")

        # -- score index candidates
        score_time = time.perf_counter()

        print("\n>> evaluating scores for index candidates\n")
        scores = _score_candidates(workload, candidates, estimator)

        score_duration_ms = (time.perf_counter() - score_time) * 1000
        if verbose:
            print(f"   took {score_duration_ms} ms.\n")
            print("score table (rows=queries, columns=index candidates)")
            # columns are numbered like the candidate listing above
            print(pd.DataFrame(scores))

        # -- select indexes greedily
        estimator_indexes = [
            candidates[pos] for pos in _select_indexes_greedily(scores)
        ]
    finally:
        # the persisted sample is only needed for scoring; never leave it
        # behind, not even when scoring fails
        print(f"\n>> dropping sample collection {sample_db}.{collection}\n")
        estimator.drop_sample()

    print(f"\n>> recommending {len(estimator_indexes)} index(es)\n")
    for idx in estimator_indexes:
        print("    ", dict.fromkeys(idx, 1))

    if before_and_after:
        print(
            "\n>> WARNING: The --before-and-after option will execute the query workload\n"
            + f"against your deployment on {uri}.\n"
            + "The workload runs once without the proposed indexes and once with the indexes.\n"
            + "Do not use this option on a production system!\n\nDo you wish to continue? [y/N]",
            end=" ",
            flush=True,
        )
        # The prompt advertises "No" as the default ([y/N]), so only an
        # explicit yes may continue; a bare "enter" or a closed stdin declines.
        try:
            choice = input().strip().lower()
        except EOFError:
            choice = ""
        if choice not in {"yes", "y", "ye"}:
            sys.exit()

        # run before creating (additional) indexes
        print("\n>> evaluating workload before creating new indexes\n")
        before_time = mcollection.execute_workload(workload)
        print(f"\n    workload took {before_time / 1000.:.2f} sec.")

        # create recommended indexes
        print("\n>> creating indexes\n")
        for idx in estimator_indexes:
            if verbose:
                print(f"    - {dict.fromkeys(idx, 1)}")
            mcollection.create_index(idx)

        # run with newly created indexes
        print("\n>> evaluating workload after creating new indexes\n")
        after_time = mcollection.execute_workload(workload)
        print(f"\n    workload took {after_time / 1000.:.2f} sec to execute.")


if __name__ == "__main__":
    # Command-line entry point. Every option's destination name matches a
    # keyword parameter of main(), so the parsed namespace is passed on as-is.
    parser = argparse.ArgumentParser(
        description="Experimental Index Recommendation Tool for MongoDB."
    )

    # URI, database and collection arguments
    parser.add_argument(
        "--uri",
        type=str,
        metavar="<uri>",
        help="mongodb uri connection string",
        required=True,
    )
    parser.add_argument(
        "-d", "--db", metavar="<db>", type=str, help="database name", required=True
    )
    parser.add_argument(
        "-c",
        "--collection",
        type=str,
        metavar="<coll>",
        help="collection name",
        required=True,
    )

    # sampling options; defaults and help texts are derived from the module
    # constants so they cannot drift apart
    parser.add_argument(
        "--sample-ratio",
        type=float,
        default=DEFAULT_SAMPLE_RATIO,
        metavar="<sr>",
        help=f"sample ratio (default={DEFAULT_SAMPLE_RATIO})",
    )
    parser.add_argument(
        "--sample-db",
        type=str,
        default=DEFAULT_SAMPLE_DB,
        metavar="<db>",
        help=f"sample database name (default={DEFAULT_SAMPLE_DB})",
    )

    parser.add_argument(
        "--before-and-after",
        action="store_true",
        help="executes and compares workload execution times without/with indexes.",
    )

    parser.add_argument("-v", "--verbose", action="store_true", help="verbose output")

    args = parser.parse_args()
    main(**vars(args))
