#!/usr/bin/env python3
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "docopt-ng",
#     "loguru",
#     "numpy",
#     "pandas",
#     "uproot",
#     "pyarrow",
#     "fastparquet",
# ]
# ///

"""
Usage:
    amptools-to-laddu <input_file> <output_file> [--tree <treename>] [--pol-in-beam] [-n <num-entries>]

Options:
    --tree <treename>        The tree name in the ROOT file [default: kin].
    --pol-in-beam            Use the beam's momentum for polarization (eps).
    -n <num-entries>         Truncate the file to the first n entries for testing.
"""

import numpy as np
import pandas as pd
import uproot
from docopt import docopt
from loguru import logger


def read_root_file(input_file, tree_name, pol_in_beam, num_entries=None):
    """Read a ROOT tree and return four-momenta, weights and polarization (eps).

    The beam branches (``E_Beam``, ``Px_Beam``, ``Py_Beam``, ``Pz_Beam``), the
    final-state branches (``E_FinalState``, ``Px_FinalState``,
    ``Py_FinalState``, ``Pz_FinalState``) and ``Weight`` are read from the
    tree. The polarization vector is taken from exactly one source, in this
    order of precedence:

    1. an ``eps`` branch, if the tree has one;
    2. otherwise an ``EPS`` branch, if the tree has one;
    3. otherwise, when ``pol_in_beam`` is true, the beam's (Px, Py, Pz)
       branches, in which case the beam four-vector is rewritten as
       (E, 0, 0, E);
    4. otherwise zeros.

    Args:
        input_file: Path (or anything ``uproot.open`` accepts) of the ROOT file.
        tree_name: Key of the tree inside the file.
        pol_in_beam: If true and no eps/EPS branch exists, interpret the beam
            momentum branches as the polarization vector.
        num_entries: Passed to uproot as ``entry_stop``; ``None`` reads
            everything.

    Returns:
        A tuple ``(p4s, weight, eps)`` where ``p4s`` is the beam four-vector
        followed by the final-state four-vectors (cast to float32), ``weight``
        is the ``Weight`` branch as read, and ``eps`` is cast to float32 with
        a length-1 axis inserted at position 1.
    """
    logger.info(f"Reading ROOT file: {input_file}")
    with uproot.open(input_file) as file:
        tree = file[tree_name]
        logger.info(f"Using tree: {tree_name}")

        def read_branch(name):
            # Every branch is read the same way: as NumPy, honouring the
            # optional truncation.
            return tree[name].array(library="np", entry_stop=num_entries)

        # Beam and weight branches.
        E_beam = read_branch("E_Beam")
        Px_beam = read_branch("Px_Beam")
        Py_beam = read_branch("Py_Beam")
        Pz_beam = read_branch("Pz_Beam")
        weight = read_branch("Weight")

        # Final-state branches. Rebuilding the array from its per-event items
        # turns a per-event container into a regular 2-D array.
        # NOTE(review): assumes every event has the same number of final-state
        # particles; confirm for the input files in use.
        E_final = np.array(list(read_branch("E_FinalState")))
        Px_final = np.array(list(read_branch("Px_FinalState")))
        Py_final = np.array(list(read_branch("Py_FinalState")))
        Pz_final = np.array(list(read_branch("Pz_FinalState")))

        # Beam four-vector: (nevents, 4)
        p4_beam = np.stack([E_beam, Px_beam, Py_beam, Pz_beam], axis=-1)

        # Final-state four-vectors: (nevents, nparticles, 4)
        p4_final = np.stack([E_final, Px_final, Py_final, Pz_final], axis=-1)

        # Choose exactly one polarization source. The branch lookup and the
        # fallbacks must form a single exclusive chain: a tree that only has
        # an upper-case "EPS" branch must not fall through to the beam/zero
        # fallbacks and lose the values it just read.
        eps_branch = next((name for name in ("eps", "EPS") if name in tree), None)
        if eps_branch is not None:
            logger.info(f"{eps_branch} branch found. Using it for eps values.")
            eps = read_branch(eps_branch)[:, np.newaxis, :]
        elif pol_in_beam:
            logger.info("Using beam's momentum for polarization (eps).")
            eps = np.stack([Px_beam, Py_beam, Pz_beam], axis=-1)[:, np.newaxis]
            # The momentum branches held the polarization, so rebuild the
            # beam four-vector as (E, 0, 0, E).
            p4_beam[:, 1:] = 0  # Zero Px, Py and Pz
            p4_beam[:, 3] = E_beam  # Set Pz = E for beam
        else:
            logger.info("Using default or provided eps values.")
            eps = np.zeros((len(E_beam), 1, 3), dtype=np.float32)  # Default to 0

        # Concatenate the beam and final state particles: (nevents, nparticles+1, 4)
        logger.info("Concatenating beam and final state particles.")
        p4s = np.concatenate([p4_beam[:, np.newaxis, :], p4_final], axis=1)

        return p4s.astype(np.float32), weight, eps.astype(np.float32)


def save_as_parquet(p4s, weight, eps, output_file):
    """Write four-momenta, polarization vectors and weights to a Parquet file.

    Each particle index ``i`` along axis 1 of ``p4s`` becomes the columns
    ``p4_{i}_E``, ``p4_{i}_Px``, ``p4_{i}_Py`` and ``p4_{i}_Pz``; each index
    along axis 1 of ``eps`` becomes ``eps_{i}_x``, ``eps_{i}_y`` and
    ``eps_{i}_z``. A final ``weight`` column holds ``weight``. The DataFrame
    index is not written.
    """
    logger.info("Saving data to Parquet format.")

    # Insertion order of this dict is the column order of the output file.
    table = {}

    # One column per (particle, four-vector component).
    for particle in range(p4s.shape[1]):
        for slot, component in enumerate(("E", "Px", "Py", "Pz")):
            table[f"p4_{particle}_{component}"] = p4s[:, particle, slot]

    # One column per (eps vector, Cartesian component).
    for vector in range(eps.shape[1]):
        for slot, axis in enumerate(("x", "y", "z")):
            table[f"eps_{vector}_{axis}"] = eps[:, vector, slot]

    # Event weights go last.
    table["weight"] = weight

    frame = pd.DataFrame(table)
    frame.to_parquet(output_file, index=False)
    logger.info(f"File saved: {output_file}")


def main():
    """Parse the command line, convert the ROOT file, and preview the result.

    After writing the Parquet file it is read back, and its first rows and
    column names are printed to stdout.
    """
    cli = docopt(__doc__)
    input_file = cli["<input_file>"]
    output_file = cli["<output_file>"]
    tree_name = cli["--tree"]
    pol_in_beam = cli["--pol-in-beam"]

    # "-n" is optional; when it is not given, read the whole tree.
    entry_limit = cli["-n"]
    num_entries = int(entry_limit) if entry_limit else None

    startup_messages = (
        "Starting conversion: ROOT -> Parquet",
        f"Input file: {input_file}",
        f"Output file: {output_file}",
        f"Tree name: {tree_name}",
        f"Polarization in beam: {pol_in_beam}",
    )
    for message in startup_messages:
        logger.info(message)

    p4s, weight, eps = read_root_file(input_file, tree_name, pol_in_beam, num_entries)
    save_as_parquet(p4s, weight, eps, output_file)

    # Read the file back so the user can see what was actually written.
    preview = pd.read_parquet(output_file)
    print("Output Parquet File (head):")
    print(preview.head())
    print("Output Columns:")
    for name in preview.columns:
        print(name)


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
