#!/usr/bin/env python3

import argparse
import json
import os
import sys
from threading import Lock

if __name__ == '__main__':  # noqa
    pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))  # noqa
    sys.path.insert(0, pkg_root)  # noqa

from hca.util.pool import ThreadPool

from dcp_diag.finders import Finder
from dcp_diag.api_agents import DataStoreAgent

# Verbosity levels, lowest to highest.  A message tagged with one of these is
# written only when its level does not exceed the active `verbosity_level`.
(
    V_SILENT,       # 0: written at every verbosity
    V_SUMMARY,      # 1: per-phase summary lines (the default)
    V_BAD_DETAIL,   # 2: additionally list bundles that failed a check
    V_GOOD_DETAIL,  # 3: additionally list bundles that passed a check
) = range(4)

# Active verbosity; rebound from the -v/--verbose count once arguments are parsed.
verbosity_level = V_SUMMARY


def output(message, message_verbosity=V_SILENT):
    """
    Write `message` to stdout, unbuffered, if the current verbosity allows it.

    Nothing is appended to `message` (no newline), so callers can build up a
    line in several calls or redraw it with a leading carriage return.

    :param message: text to write
    :param message_verbosity: one of the V_* levels; the message is dropped
        when this is greater than the module-level `verbosity_level`
    """
    if message_verbosity > verbosity_level:
        return
    stream = sys.stdout
    stream.write(message)
    stream.flush()


class AnalyzeSubmission:

    """
    Check that the bundles of an Ingest submission are present and searchable in the DSS.

    Start with a Submission ID, then check:
      ✔     Primary bundles (known by Ingest)
      ✔     Primary bundles in DSS/AWS (direct access)
      ✔     Primary bundles in DSS/GCP (direct access)
      ✔     Primary bundles searchable in DSS/AWS using Project ID
      ✔     Primary bundles searchable in DSS/GCP using Project ID
      ✔     Secondary bundles searchable in DSS/AWS using files.analysis_process_json.input_bundles
      ✔     Secondary bundles searchable in DSS/GCP using files.analysis_process_json.input_bundles
      TODO  Secondary bundles searchable in DSS/AWS using Project ID
      TODO  Secondary bundles searchable in DSS/GCP using Project ID

    Results are cached in "<submission_id>.json" so that a later run can reload and
    update them; pass --fresh to ignore an existing cache file.
    """

    class AnalysisState:
        """
        Everything learned about the submission's bundles, shared by all phases.

        `bundle_map` maps a bundle UUID to a dict of facts about that bundle
        (at least a 'type' entry); `lock` is the lock worker threads take
        before changing the map.  The map can be written to and read back
        from a JSON file.
        """

        def __init__(self):
            self.lock = Lock()
            self.bundle_map = {}

        def _primary_uuids(self):
            """Return a list snapshot of the UUIDs whose entry has type 'primary'."""
            return [uuid for uuid, info in self.bundle_map.items() if info['type'] == 'primary']

        @property
        def primary_bundle_count(self):
            """Number of entries in `bundle_map` whose type is 'primary'."""
            return sum(1 for info in self.bundle_map.values() if info['type'] == 'primary')

        def iter_primary_bundles(self):
            """Yield the UUID of every primary bundle, from a snapshot taken when iteration starts."""
            yield from self._primary_uuids()

        def save(self, filename):
            """Write `bundle_map` to `filename` as indented JSON."""
            output(f"\tSaving state in {filename}...", V_SUMMARY)
            with open(filename, 'w') as state_file:
                serialized = json.dumps(self.bundle_map, indent=4)
                state_file.write(serialized)
            output("done.\n", V_SUMMARY)

        def load(self, filename):
            """Replace `bundle_map` with the JSON document stored in `filename`."""
            output(f"\tLoading state from {filename}...", V_SUMMARY)
            with open(filename, 'r') as state_file:
                self.bundle_map = json.load(state_file)
            output("done\n", V_SUMMARY)

    class DSSBundlePresenceChecker:
        """
        Check whether each primary bundle's manifest can be fetched from each DSS replica.

        The outcome is stored in the shared state as
        ``bundle_map[uuid][replica]['dss_presence']`` (True or False).  When a
        manifest is found, the bundle's ``fqid`` ("<uuid>.<version>") is
        recorded as well.
        """

        _REPLICAS = ('aws', 'gcp')

        def __init__(self, deployment, state, options):
            """
            :param deployment: deployment name passed on to DataStoreAgent
            :param state: shared analysis state providing `bundle_map`, `lock`,
                `primary_bundle_count` and `iter_primary_bundles()`
            :param options: parsed arguments; `jobs` sets the thread-pool size
            """
            self.deployment = deployment
            self.state = state
            self.options = options

            self.primary_bundle_count = self.state.primary_bundle_count
            # Lookups completed so far, per replica; updated by worker threads under state.lock.
            self.checked_bundles = {'aws': 0, 'gcp': 0}

        def check(self):
            """Run a manifest lookup for every (primary bundle, replica) pair not already marked present."""
            output("\tChecking for bundle manifests:", V_SUMMARY)
            pool = ThreadPool(self.options.jobs)
            for bundle_uuid in self.state.iter_primary_bundles():
                for replica in self._REPLICAS:
                    replica_data = self.state.bundle_map[bundle_uuid][replica]
                    # A truthy stored value means the bundle was already found (e.g. in a
                    # cached state); a missing entry gets an empty placeholder and is checked.
                    if replica_data.setdefault('dss_presence', {}):
                        continue
                    pool.add_task(self._check_bundle_manifest_exists, bundle_uuid, replica)
            pool.wait_for_completion()
            output("...done.\n", V_SUMMARY)

        def _check_bundle_manifest_exists(self, bundle_uuid, replica):
            """
            Worker task: fetch one bundle manifest and record whether it exists.

            An AssertionError from the lookup is treated as "bundle missing from
            this replica"; any other exception propagates.
            """
            dss = DataStoreAgent(self.deployment)
            try:
                manifest = dss.bundle_manifest(bundle_uuid, replica)
            except AssertionError:
                self._record_absence(bundle_uuid, replica)
            else:
                self._record_presence(bundle_uuid, replica, manifest)
            self._print_progress()

        def _record_presence(self, bundle_uuid, replica, manifest):
            """Mark the bundle present in `replica`, store its fqid, and count the lookup."""
            with self.state.lock:
                bundle_info = self.state.bundle_map[bundle_uuid]
                bundle_info['fqid'] = ".".join([manifest['bundle']['uuid'], manifest['bundle']['version']])
                bundle_info[replica]['dss_presence'] = True
                # The counter is shared by all workers and `+=` is a read-modify-write,
                # so it is updated while the lock is held to avoid lost increments.
                self.checked_bundles[replica] += 1

        def _record_absence(self, bundle_uuid, replica):
            """Mark the bundle absent from `replica`, report it, and count the lookup."""
            with self.state.lock:
                self.state.bundle_map[bundle_uuid][replica]['dss_presence'] = False
                output(f"\rbundle {bundle_uuid} is missing from {replica.upper()}\n", V_BAD_DETAIL)
                self.checked_bundles[replica] += 1

        def _print_progress(self):
            """Redraw the progress line in place; only done when stdout is a terminal."""
            if sys.stdout.isatty():
                output(f"\r\tChecking for bundle manifests: "
                       f"AWS: {self.checked_bundles['aws']}/{self.primary_bundle_count}"
                       f" GCP: {self.checked_bundles['gcp']}/{self.primary_bundle_count}", V_SUMMARY)

        def print_results(self):
            """
            Report, per replica, how many bundles are present and how many are absent.

            Both counts are summary-level output.  The individual bundles are
            listed at V_GOOD_DETAIL (present) and V_BAD_DETAIL (absent).
            """
            for replica in self._REPLICAS:
                present_bundles = {}
                absent_bundles = {}
                for uuid, bundle_info in self.state.bundle_map.items():
                    bucket = present_bundles if bundle_info[replica]['dss_presence'] else absent_bundles
                    bucket[uuid] = bundle_info

                output(f"\t{len(present_bundles)} bundle are present in {replica.upper()}\n", V_SUMMARY)
                if verbosity_level >= V_GOOD_DETAIL:
                    for uuid, bundle_info in present_bundles.items():
                        print(f"\t{bundle_info.get('fqid', uuid)}")

                if absent_bundles:
                    # The count is a summary line, so it is no longer hidden behind the
                    # detail-level check that guards the per-bundle listing.
                    output(f"\t{len(absent_bundles)} bundle is absent from {replica.upper()}\n", V_SUMMARY)
                    if verbosity_level >= V_BAD_DETAIL:
                        for uuid, bundle_info in absent_bundles.items():
                            print(f"\t{bundle_info.get('fqid', uuid)}")

    class SearchDSSbyProjectUUID:
        """
        Search each DSS replica for the project's non-analysis bundles and mark the ones found.

        A known bundle returned by the search gets
        ``bundle_map[uuid][replica]['in_dss_project_search'] = True``.  Bundles
        the search returns that are not in ``bundle_map`` are collected in
        ``unknown_bundles`` and reported, instead of aborting the run.
        """

        def __init__(self, project_uuid, deployment, state, options):
            """
            :param project_uuid: value matched against files.project_json.provenance.document_id
            :param deployment: deployment name passed on to DataStoreAgent
            :param state: shared analysis state providing `bundle_map` and `lock`
            :param options: parsed arguments (stored, not read by this class)
            """
            self.project_uuid = project_uuid
            self.deployment = deployment
            self.state = state
            self.options = options
            # fqids returned by the search that have no entry in state.bundle_map, per replica.
            self.unknown_bundles = {'aws': [], 'gcp': []}

        def check(self):
            """Run the project search against both replicas and record every hit."""
            output("\tSearching DSS...", V_SUMMARY)
            dss = DataStoreAgent(self.deployment)
            query = {
                "query": {
                    "bool": {
                        "must": [
                            {
                                "match": {
                                    "files.project_json.provenance.document_id": self.project_uuid
                                }
                            }
                        ],
                        "must_not": [
                            {
                                "match": {
                                    "files.analysis_process_json.process_type.text": "analysis"
                                }
                            }
                        ]
                    }
                }
            }

            for replica in ['aws', 'gcp']:
                for result in dss.search(query, replica=replica):
                    self._record_search_hit(replica, result['bundle_fqid'])
            output("done.\n", V_SUMMARY)

        def _record_search_hit(self, replica, bundle_fqid):
            """
            Record that `bundle_fqid` was returned by the project search on `replica`.

            :return: True if the bundle is known and was marked; False if it has
                no entry in `bundle_map` (it is then remembered in `unknown_bundles`)
            :raises AssertionError: if the bundle already has a different fqid recorded
            """
            bundle_uuid = bundle_fqid.split('.', 1)[0]
            with self.state.lock:
                bundle_info = self.state.bundle_map.get(bundle_uuid)
                if bundle_info is None:
                    # TODO if we don't know about this bundle at all, create a new entry in bundle_map.
                    # Until then, remember it for the report rather than dying on a KeyError.
                    self.unknown_bundles[replica].append(bundle_fqid)
                    return False
                bundle_info[replica]['in_dss_project_search'] = True
                known_fqid = bundle_info.setdefault('fqid', bundle_fqid)
                if known_fqid != bundle_fqid:
                    # Raised explicitly (same exception type as the former `assert`)
                    # so the check is not stripped when Python runs with -O.
                    raise AssertionError(
                        f"bundle {bundle_uuid}: search returned {bundle_fqid} but {known_fqid} was recorded earlier")
            return True

        def print_results(self):
            """Print the project-search report for each replica."""
            for replica in ['aws', 'gcp']:
                self._print_results_for_replica(replica)

        def _split_primary_bundles(self, replica):
            """Return (indexed, not_indexed) dicts of the primary bundles for `replica`."""
            indexed = {}
            not_indexed = {}
            for bundle_uuid, bundle_info in self.state.bundle_map.items():
                if bundle_info['type'] != 'primary':
                    continue
                target = indexed if bundle_info[replica].get('in_dss_project_search') else not_indexed
                target[bundle_uuid] = bundle_info
            return indexed, not_indexed

        def _print_results_for_replica(self, replica):
            """
            Report how many primary bundles the project search did and did not return.

            Counts are summary-level output.  Individual bundles are listed at
            V_GOOD_DETAIL (indexed) and V_BAD_DETAIL (not indexed, or unknown).
            """
            indexed, not_indexed = self._split_primary_bundles(replica)

            output(f"\tIn {replica.upper()} DSS, "
                   f"{len(indexed)} bundles are indexed by project\n", V_SUMMARY)
            if verbosity_level >= V_GOOD_DETAIL:
                for bundle_uuid, bundle_info in indexed.items():
                    print(f"\t{bundle_info.get('fqid', bundle_uuid)}")

            if not_indexed:
                output(
                    f"\tIn {replica.upper()} DSS, "
                    f"{len(not_indexed)} bundles are NOT indexed by project\n", V_SUMMARY)
                if verbosity_level >= V_BAD_DETAIL:
                    for bundle_uuid, bundle_info in not_indexed.items():
                        print(f"\t{bundle_info.get('fqid', bundle_uuid)}")

            unknown = self.unknown_bundles[replica]
            if unknown:
                output(
                    f"\tIn {replica.upper()} DSS, "
                    f"{len(unknown)} bundles found by project search are not in this submission's bundle list\n",
                    V_SUMMARY)
                if verbosity_level >= V_BAD_DETAIL:
                    for bundle_fqid in unknown:
                        print(f"\t{bundle_fqid}")

    class SearchDSSforSecondaryBundles:
        """
        For every primary bundle, search each DSS replica for bundles that list it as an input.

        The fqids found are appended to
        ``bundle_map[primary_uuid][replica]['results_bundles']``.
        """

        _REPLICAS = ('aws', 'gcp')

        def __init__(self, deployment, state, options):
            """
            :param deployment: deployment name passed on to DataStoreAgent
            :param state: shared analysis state providing `bundle_map`, `lock`,
                `primary_bundle_count` and `iter_primary_bundles()`
            :param options: parsed arguments; `jobs` sets the thread-pool size
            """
            self.deployment = deployment
            self.state = state
            self.options = options

            self.primary_bundle_count = self.state.primary_bundle_count
            # Searches completed so far, per replica; updated by worker threads under state.lock.
            self.checked_bundles = {'aws': 0, 'gcp': 0}

        def check(self):
            """Run a search for every (primary bundle, replica) pair that has no results recorded yet."""
            output("\tSearching for secondary bundles: ", V_SUMMARY)
            pool = ThreadPool(self.options.jobs)
            for pri_uuid in self.state.iter_primary_bundles():
                for replica in self._REPLICAS:
                    replica_data = self.state.bundle_map[pri_uuid][replica]
                    # A non-empty stored list means results are already known; a missing
                    # entry is initialised to an empty list and searched for.
                    if replica_data.setdefault('results_bundles', []):
                        continue
                    pool.add_task(self._find_secondary_bundles_for_primary_bundle, pri_uuid, replica)
            pool.wait_for_completion()
            output("...done.\n", V_SUMMARY)

        def _find_secondary_bundles_for_primary_bundle(self, pri_uuid, replica):
            """Worker task: search one replica for bundles whose input_bundles match `pri_uuid`."""
            dss = DataStoreAgent(self.deployment)
            query = {
                "query": {
                    "match": {
                        "files.analysis_process_json.input_bundles": pri_uuid
                    }
                }
            }
            results = dss.search(query, replica=replica)
            self._record_results(pri_uuid, replica, [result['bundle_fqid'] for result in results])
            self._print_progress()

        def _record_results(self, pri_uuid, replica, bundle_fqids):
            """Append `bundle_fqids` to the primary bundle's results for `replica` and count the search."""
            with self.state.lock:
                self.state.bundle_map[pri_uuid][replica]['results_bundles'].extend(bundle_fqids)
                # The counter is shared by all workers and `+=` is a read-modify-write,
                # so it is updated while the lock is held to avoid lost increments.
                self.checked_bundles[replica] += 1

        def _print_progress(self):
            """Redraw the progress line in place; only done when stdout is a terminal."""
            if sys.stdout.isatty():
                output(f"\r\tSearching for secondary bundles: "
                       f"AWS: {self.checked_bundles['aws']}/{self.primary_bundle_count}"
                       f" GCP: {self.checked_bundles['gcp']}/{self.primary_bundle_count}", V_SUMMARY)

        def print_results(self):
            """Print the secondary-bundle report for AWS, then for GCP."""
            self._print_results_for_replica('aws')
            self._print_results_for_replica('gcp')

        def _group_by_result_count(self, replica):
            """Return {number of results bundles: {primary uuid: its results list}} for `replica`."""
            groups = {}
            for pri_uuid, bundle_info in self.state.bundle_map.items():
                secondaries = bundle_info[replica]['results_bundles']
                groups.setdefault(len(secondaries), {})[pri_uuid] = secondaries
            return groups

        def _print_results_for_replica(self, replica):
            """
            Report how many primary bundles have 0, 1, 2, ... results bundles, in ascending order.

            The bundles with no results are listed at V_BAD_DETAIL; those with
            results are listed at V_GOOD_DETAIL.
            """
            groups = self._group_by_result_count(replica)
            for result_count in sorted(groups):
                pri_sec = groups[result_count]
                output(f"\tIn {replica.upper()} there are {len(pri_sec)} primary bundles "
                       f"with {result_count} results bundles\n",
                       V_SUMMARY)
                detail_level = V_BAD_DETAIL if result_count == 0 else V_GOOD_DETAIL
                if verbosity_level >= detail_level:
                    for pri, sec in pri_sec.items():
                        bundle_id = self.state.bundle_map[pri].get('fqid', pri)
                        print(f"\t\tprimary: {bundle_id} secondary: {sec}")

    def __init__(self):
        """
        Parse the command line and run the complete analysis.

        Steps, in order: choose the deployment, fetch the submission, obtain the
        primary bundle list (from Ingest, or from "<submission_id>.json" when
        that file exists and --fresh was not given), run the three DSS checks,
        then save the state back to that file.
        """
        global verbosity_level

        cli = argparse.ArgumentParser()
        cli.add_argument('-d', '--deployment', help="search this deployment")
        cli.add_argument('submission_id')
        cli.add_argument('-v', '--verbose', default=V_SUMMARY, action='count', dest='verbosity',
                         help="provide more detail (can be added multiple times)")
        cli.add_argument('-j', '--jobs', type=int, default=10,
                         help="concurrently level to use (default: 10)")
        cli.add_argument('-f', '--fresh', action='store_true',
                         help="don't start with saved state (if present)")
        options = cli.parse_args()

        verbosity_level = options.verbosity
        self.state = self.AnalysisState()
        self.deployment = self._choose_deployment(options)

        submission = self._retreive_submission(options.submission_id)
        project = submission.project()
        output(f"\tProject UUID: {project.uuid}\n", V_SUMMARY)

        cache_file = f"{options.submission_id}.json"
        if not options.fresh and os.path.isfile(cache_file):
            output("\nPHASE 1: Loading cached state:\n", V_SUMMARY)
            self.state.load(filename=cache_file)
        else:
            output("\nPHASE 1: Get primary bundle list from Ingest:\n", V_SUMMARY)
            self._get_primary_bundle_list_from_ingest(submission)

        # Each checker is built only when its phase starts, after the banner is printed.
        phases = (
            ("\nPHASE 2: Checking bundles are present in DSS:\n",
             lambda: self.DSSBundlePresenceChecker(self.deployment, self.state, options=options)),
            ("\nPHASE 3: Check DSS for primary bundles with this project UUID:\n",
             lambda: self.SearchDSSbyProjectUUID(project_uuid=project.uuid, deployment=self.deployment,
                                                 state=self.state, options=options)),
            ("\nPHASE 4: Check DSS for secondary bundles:\n",
             lambda: self.SearchDSSforSecondaryBundles(deployment=self.deployment, state=self.state,
                                                       options=options)),
        )
        for banner, make_checker in phases:
            output(banner, V_SUMMARY)
            checker = make_checker()
            checker.check()
            checker.print_results()

        output("\nPHASE 6: Save state:\n", V_SUMMARY)
        self.state.save(cache_file)

    def _choose_deployment(self, args):
        if 'deployment' in args and args.deployment:
            deployment = args.deployment
        elif 'DEPLOYMENT_STAGE' in os.environ:
            deployment = os.environ['DEPLOYMENT_STAGE']
            answer = input(f"Use deployment {deployment}? (y/n): ")
            if answer is not 'y':
                exit(1)
        else:
            print("You must supply the --deployment argument or set environment variable DEPLOYMENT_STAGE")
            sys.exit(1)
        output(f"Using deployment: {deployment}\n", V_SUMMARY)
        return deployment

    def _retreive_submission(self, submission_id):
        """
        Look up the submission through the "ingest" finder and print it.

        :param submission_id: ID placed in the finder expression "subm_id=<id>"
        :return: whatever the finder returns for that expression
        """
        output("\nRetreiving submission...", V_SUMMARY)
        ingest_finder = Finder.factory(finder_name="ingest", deployment=self.deployment)
        found = ingest_finder.find(f"subm_id={submission_id}")
        for text in ("done.\n", str(found)):
            output(text, V_SUMMARY)
        return found

    def _get_primary_bundle_list_from_ingest(self, subm):
        """
        Seed the state with one 'primary' entry per bundle UUID of the submission.

        Each entry starts with empty per-replica dicts for 'aws' and 'gcp'.
        The UUIDs are listed, sorted, at V_GOOD_DETAIL verbosity.

        :param subm: submission object; its bundles() supplies the bundle UUIDs
        """
        output("\tRetrieving submission's primary bundle list...", V_SUMMARY)
        bundle_map = self.state.bundle_map
        bundle_map.update(
            (pri_uuid, {'type': 'primary', 'aws': {}, 'gcp': {}})
            for pri_uuid in subm.bundles()
        )
        output("done.\n", V_SUMMARY)

        output(f"\tIngest created {len(bundle_map)} bundles.\n", V_SUMMARY)
        if verbosity_level < V_GOOD_DETAIL:
            return
        for bundle_uuid in sorted(bundle_map):
            print(f"\t{bundle_uuid}")


# Run the analysis: constructing AnalyzeSubmission parses sys.argv and executes every phase.
# NOTE(review): there is no `if __name__ == '__main__':` guard, so merely importing this
# module also runs it — confirm nothing relies on that before adding one.
AnalyzeSubmission()
