#!/usr/bin/env python3

"""
Loads Graph embeddings into TrustGraph processing.
"""

import pulsar
from pulsar.schema import JsonSchema
from trustgraph.schema import GraphEmbeddings, Value
from trustgraph.schema import graph_embeddings_store_queue
import argparse
import os
import time
import pyarrow as pa
import pyarrow.parquet as pq

from trustgraph.log_level import LogLevel

class Loader:
    """Publish rows of a Parquet file to a Pulsar topic as GraphEmbeddings.

    The file is read with ``pyarrow.parquet`` and must contain an
    ``embeddings`` column and an ``entity`` column; each row is sent as one
    ``GraphEmbeddings`` message.
    """

    def __init__(
            self,
            pulsar_host,
            output_queue,
            log_level,
            file,
    ):
        """Connect to Pulsar and create the producer used by ``run``.

        Args:
            pulsar_host: Service URL handed to ``pulsar.Client``.
            output_queue: Topic the producer publishes to.
            log_level: Object whose ``to_pulsar()`` result configures the
                Pulsar console logger.
            file: Path of the Parquet file that ``run`` reads.
        """

        self.client = pulsar.Client(
            pulsar_host,
            logger=pulsar.ConsoleLogger(log_level.to_pulsar())
        )

        self.producer = self.client.create_producer(
            topic=output_queue,
            schema=JsonSchema(GraphEmbeddings),
            chunking_enabled=True,
        )

        self.file = file

    def run(self):
        """Read the file and send one message per row.

        Any exception raised while reading or sending is printed and not
        propagated, so this method returns normally even when loading
        failed. If a required column is absent, that is reported and
        nothing is sent.
        """

        try:

            print("Reading file...")
            table = pq.read_table(self.file)
            print("Loaded.")

            names = set(table.column_names)

            # Report every missing column, then stop: looking the column up
            # anyway would only fail with a second, less helpful error.
            missing = [
                col for col in ("embeddings", "entity") if col not in names
            ]
            for col in missing:
                print(f"No '{col}' column")
            if missing:
                return

            embc = table.column("embeddings")
            entc = table.column("entity")

            for emb, ent in zip(embc, entc):

                # Convert the Arrow scalars to plain Python values.
                vectors = emb.as_py()
                name = ent.as_py()

                message = GraphEmbeddings(
                    vectors=vectors,
                    entity=Value(
                        value=name,
                        # Only values with an "https:" prefix are flagged
                        # as URIs.
                        is_uri=name.startswith("https:")
                    )
                )

                self.producer.send(message)

        except Exception as e:
            print(e, flush=True)

    def __del__(self):
        """Close the Pulsar client, if one was created.

        ``__init__`` can fail before ``self.client`` is assigned, so the
        attribute is looked up defensively. The reference is cleared before
        closing so that a repeated call does not close the client twice.
        """
        client = getattr(self, "client", None)
        if client is not None:
            self.client = None
            client.close()

def main():
    """Parse command-line arguments and load the file into Pulsar.

    A ``Loader`` is built from the arguments and run. If constructing or
    running it raises, the error is printed and the whole attempt is
    repeated after a 10 second pause; the loop ends after the first attempt
    that completes without raising.
    """

    parser = argparse.ArgumentParser(
        prog='loader',
        description=__doc__,
    )

    default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
    default_output_queue = graph_embeddings_store_queue
    default_log_level = LogLevel.ERROR

    parser.add_argument(
        '-p', '--pulsar-host',
        default=default_pulsar_host,
        help=f'Pulsar host (default: {default_pulsar_host})',
    )

    parser.add_argument(
        '-o', '--output-queue',
        default=default_output_queue,
        help=f'Output queue (default: {default_output_queue})'
    )

    parser.add_argument(
        '-l', '--log-level',
        type=LogLevel,
        default=default_log_level,
        choices=list(LogLevel),
        # Derived from the real default so the help text cannot drift.
        help=f'Log level (default: {default_log_level})'
    )

    parser.add_argument(
        '-f', '--file',
        required=True,
        help='File to load'
    )

    args = parser.parse_args()

    while True:

        try:
            p = Loader(
                pulsar_host=args.pulsar_host,
                output_queue=args.output_queue,
                log_level=args.log_level,
                file=args.file,
            )

            p.run()

            print("File loaded.")
            break

        except Exception as e:

            print("Exception:", e, flush=True)
            print("Will retry...", flush=True)

        # Only reached after a failed attempt; back off before retrying.
        time.sleep(10)

# Run the loader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

