#!/usr/bin/env python3

from multiprocessing.pool import ThreadPool
import logging
import warnings

from google.cloud import bigquery


# Project ID used by create_stable() to look up the table whose clustering,
# partitioning, description and labels are copied onto the new table.
SHARED_PROD = "moz-fx-data-shared-prod"
# Project ID whose datasets main() lists to find the live tables to process.
SHARED_NONPROD = "moz-fx-data-shar-nonprod-efed"


def create_stable(client, ref):
    """Create the ``_stable`` counterpart of a live table.

    The new table lives in the same project as the live table, in the dataset
    whose name is the live dataset name with ``_live`` replaced by
    ``_stable``, and reuses the live table's id and schema. Its clustering,
    partitioning, description and labels are copied from the table in
    ``SHARED_PROD`` that has the same dataset name and the same table id with
    ``_k8s`` removed.

    Args:
        client: Object exposing ``get_table`` and ``create_table``; ``main``
            passes a ``bigquery.Client``.
        ref: Reference to the live table, handed straight to
            ``client.get_table``.

    Any exception raised by the client calls (for example when the prod table
    cannot be fetched or the create request is rejected) propagates to the
    caller.
    """
    source = client.get_table(ref)
    target_dataset = source.dataset_id.replace("_live", "_stable")
    target_id = f"{source.project}.{target_dataset}.{source.table_id}"
    stable = bigquery.Table(target_id, source.schema)

    # The prod table carries the settings to mirror; its id has no "_k8s".
    prod_table_id = stable.table_id.replace("_k8s", "")
    prod_id = f"{SHARED_PROD}.{stable.dataset_id}.{prod_table_id}"
    prod = client.get_table(prod_id)

    # Order matters: properties are applied one by one in this sequence.
    mirrored_properties = (
        "clustering_fields",
        "description",
        "labels",
        "partitioning_type",
        "range_partitioning",
        "time_partitioning",
    )
    for name in mirrored_properties:
        setting = getattr(prod, name, None)
        if setting is None:
            # Nothing configured on prod for this property; keep the default.
            continue
        setattr(stable, name, setting)

    client.create_table(stable)
    logging.info(f"{stable.dataset_id}.{stable.table_id}: created")


def main():
    """Create stable tables for every ``*_k8s`` table in ``*_live`` datasets.

    Lists the datasets of ``SHARED_NONPROD`` whose id ends with ``_live``,
    lists their tables using a pool of 10 threads, and then calls
    ``create_stable`` (on the same pool) for each table whose id ends with
    ``_k8s``. Warnings coming from ``google.auth._default`` are ignored and
    the root logger level is set to INFO.
    """
    warnings.filterwarnings("ignore", module="google.auth._default")
    client = bigquery.Client()
    logging.root.setLevel(logging.INFO)

    def tables_in(dataset_ref):
        # Materialize the listing so the paging happens inside the worker.
        return list(client.list_tables(dataset_ref))

    with ThreadPool(10) as pool:
        live_datasets = [
            dataset.reference
            for dataset in client.list_datasets(SHARED_NONPROD)
            if dataset.dataset_id.endswith("_live")
        ]
        listings = pool.map(tables_in, live_datasets)
        # One (client, table reference) argument tuple per table to process.
        work_items = [
            (client, table.reference)
            for listing in listings
            for table in listing
            if table.table_id.endswith("_k8s")
        ]
        pool.starmap(create_stable, work_items)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
