#!/usr/bin/env python

import dapgen
import fire
import pandas as pd
import numpy as np
from typing import Optional
import structlog

log = structlog.get_logger()


def score(
    plink: str,
    weights: str,
    out: str,
    weight_col_prefix: Optional[str] = None,
    freq_suffix: Optional[str] = None,
    center: bool = False,
    chrom_col: str = "CHROM",
    pos_col: str = "POS",
    ref_col: str = "REF",
    alt_col: str = "ALT",
    **kwargs,
) -> None:
    """
    Command line interface for dapgen.score

    dapgen score --plink <plink> --weights <weights> --center <center> --out <out>

    Reads a whitespace-delimited weights table, selects its numeric weight
    columns, passes them to ``dapgen.score`` and writes the returned score
    table to ``out`` as a tab-separated file.

    Args:
        plink: Forwarded to ``dapgen.score`` as ``plink_path``.
        weights: Path of the whitespace-delimited weights file. It must
            contain the four variant columns named by ``chrom_col``,
            ``pos_col``, ``ref_col`` and ``alt_col``.
        out: Destination of the tab-separated score table (floats are
            written with six decimals).
        weight_col_prefix: When given, only numeric columns whose name starts
            with this prefix are used as weights. When ``None``, every
            numeric column is used. The variant columns are never used as
            weights in either case.
        freq_suffix: Forwarded unchanged to ``dapgen.score``.
        center: Forwarded unchanged to ``dapgen.score``; must be a boolean.
        chrom_col: Name of the chromosome column in the weights file.
        pos_col: Name of the position column in the weights file.
        ref_col: Name of the reference-allele column in the weights file.
        alt_col: Name of the alternate-allele column in the weights file.
        **kwargs: Extra keyword arguments forwarded to ``dapgen.score``.

    Raises:
        AssertionError: If ``center`` is not a boolean, a required column is
            missing from the weights file, or no weight column is selected.
    """
    # Keep this as the first statement: locals() must still contain only the
    # call arguments so the log reflects exactly what the user passed in.
    log.info(
        "Received parameters: \ndapgen score\n  "
        + "\n  ".join(f"--{k}={v}" for k, v in locals().items())
    )

    # Validation raises explicitly instead of using `assert` so the checks
    # survive `python -O`; AssertionError is kept so the exception type that
    # existing callers may rely on does not change.
    if center not in (True, False):
        raise AssertionError("center must be True or False")

    log.info(f"Reading weights from {weights}")
    # sep=r"\s+" is the documented replacement for the deprecated
    # delim_whitespace=True option.
    df_weights = pd.read_csv(weights, sep=r"\s+")

    user_key_cols = [chrom_col, pos_col, ref_col, alt_col]
    canonical_key_cols = ["CHROM", "POS", "REF", "ALT"]
    if not all(col in df_weights.columns for col in user_key_cols):
        raise AssertionError(
            "weights file must have these columns "
            f"CHROM:{chrom_col}, POS:{pos_col}, REF:{ref_col}, ALT:{alt_col}"
        )

    # rename columns to their prespecified names
    df_weights = df_weights.rename(
        columns=dict(zip(user_key_cols, canonical_key_cols))
    )

    # Candidate weights are the numeric columns minus the variant columns.
    # The exclusion applies with or without a prefix, so a prefix such as
    # "P" can no longer select the numeric POS column as a weight.
    weight_cols = [
        col
        for col in df_weights.select_dtypes(include=np.number).columns
        if col not in canonical_key_cols
    ]
    if weight_col_prefix is not None:
        weight_cols = [
            col for col in weight_cols if col.startswith(weight_col_prefix)
        ]

    log.info(f"{len(df_weights)} rows x {len(weight_cols)} columns used for scoring...")
    if not weight_cols:
        raise AssertionError("weights file must have numeric columns")

    # run; only the first element of the returned pair (the score table) is
    # written out, the second one is not used here.
    df_score, _ = dapgen.score(
        plink_path=plink,
        df_weights=df_weights,
        weight_cols=weight_cols,
        center=center,
        freq_suffix=freq_suffix,
        **kwargs,
    )
    df_score.to_csv(out, sep="\t", float_format="%.6f")


# Script entry point. fire.Fire() is called without a component, so the
# command line is built from this module's contents; `score` is the
# sub-command shown in its docstring (`dapgen score ...`).
if __name__ == "__main__":
    fire.Fire()