<%
    # Names of the metrics to summarise further down the report.
    metrics = plotter.hash_vars['metrics']

    # Row 0 of the metrics table is used as its header (the plots below pass
    # header=True), so column positions are looked up there.
    metrics_header = plotter.hash_vars["metrics_table"][0]

    # Required columns: list.index raises ValueError if one is missing.
    tpr_index = metrics_header.index("tpr")
    fpr_index = metrics_header.index("fpr")
    prec_index = metrics_header.index("prec")
    coverage_index = metrics_header.index("coverage")

    # Optional grouping columns: None when the table does not carry them, so
    # the report still renders (ungrouped) instead of failing on the lookup.
    factor1_index = metrics_header.index("factor1") if "factor1" in metrics_header else None
    factor2_index = metrics_header.index("factor2") if "factor2" in metrics_header else None
%>

<%
# Settings shared by the three curve plots.
base_config = {"scatterType":"line", "colorLegendTitle":" ", "titleFontStyle":"italic", "titleScaleFontFactor":0.7}
kwargs = {}

# Compare against None rather than relying on truthiness: a column sitting at
# position 0 is valid but falsy, and would otherwise switch grouping off.
if factor1_index is not None and factor2_index is not None:
    base_config.update({"showLegend": True, "colorBy": "factor1", "segregateVariablesBy": "factor2"})
    kwargs["smp_attr"] = [factor1_index, factor2_index]
%>

<div style="width: 90%; background-color:#ecf0f1; margin: 0 auto;">
  <h1 style="text-align: center; background-color:#d6eaf8">Metrics curves</h1>

  <div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap;">

    ## ROC: the fpr column followed by the tpr column of the metrics table.
    <div>
      ${plotter.scatter2D(
          id="metrics_table",
          fields=[fpr_index, tpr_index],
          header=True,
          row_names=False,
          responsive=False,
          height="400px",
          width="400px",
          x_label="FPR",
          y_label="TPR",
          title="ROC",
          config=base_config,
          **kwargs
      )}
    </div>

    ## Precision-Recall: the tpr column (labelled Recall) followed by prec.
    <div>
      ${plotter.scatter2D(
          id="metrics_table",
          fields=[tpr_index, prec_index],
          header=True,
          row_names=False,
          responsive=False,
          height="400px",
          width="400px",
          x_label="Recall",
          y_label="Precision",
          title="Precision-Recall",
          config=base_config,
          **kwargs
      )}
    </div>

    ## Cumulative detection: the coverage column followed by the tpr column.
    <div>
      ${plotter.scatter2D(
          id="metrics_table",
          fields=[coverage_index, tpr_index],
          header=True,
          row_names=False,
          responsive=False,
          height="400px",
          width="400px",
          x_label="Coverage",
          y_label="TPR",
          title="Cumulative Detection Curve",
          config=base_config,
          **kwargs
      )}
    </div>

  </div>
</div>

<%
    # Map each column name in the first row of the summary table to its
    # position, so later sections can look columns up by name.
    name_to_index_summary = {}
    for position, column in enumerate(plotter.hash_vars["summary_table"][0]):
        name_to_index_summary[column] = position
%>

<%
    # Grouping columns are optional (the per-metric section of this report
    # treats them that way too), so only reference the ones that are present.
    # With both present this yields the same arguments as before.
    auc_factors = [factor for factor in ("factor1", "factor2") if factor in name_to_index_summary]
    auc_smp_attr = [name_to_index_summary[factor] for factor in auc_factors] or None

    auc_config = {}
    if auc_factors:
        auc_config["smpOverlays"] = auc_factors
    if "factor2" in name_to_index_summary:
        auc_config["segregateSamplesBy"] = ["factor2"]
    # "autoScaleFont" used to appear twice in this literal (True, then False);
    # only the last value took effect, so the single entry keeps False.
    auc_config.update({
        "circularTrackName": ["ROC", "PR"],
        "autoScaleFont": False,
        "circularLetterSeparationFactor": 2,
        "axisTickFontSize": 8,
        "circularOverlayThickness": 20,
        "circularCenterProportion": 0.5,
        "showSampleNames": False,
        "rAxisTextFontSize": 10,
        "color_scheme": "Tableau",
    })

    # The two AUC columns are both plotted and assigned to rings.
    auc_columns = [name_to_index_summary.get("roc_auc"), name_to_index_summary.get("pr_auc")]
%>
<div style="width: 90%; background-color:#ecf0f1; margin: 0 auto;">
<h1 style="text-align: center; background-color:#d6eaf8">AUC</h1>
<div style="display: flex; justify-content: center; align-items: center; width: 90%; background-color:#ecf0f1; margin: 0 auto;">
    ${plotter.circular(
        id="summary_table",
        header=True,
        row_names=True,
        fields=[name_to_index_summary.get("sample_id")] + auc_columns,
        title="AUC summary",
        x_label="AUC",
        smp_attr=auc_smp_attr,
        ring_assignation=list(auc_columns),
        ringsType=["bar", "bar"],
        config=auc_config,
        height=700,
        width=700
    )}
</div>
</div>


<div style="width: 90%; background-color:#ecf0f1; margin: 0 auto;">
<h1 style="text-align: center; background-color:#d6eaf8">Summary of metrics</h1>
## For every remaining metric: a bar plot of the "<metric>_best" column and a
## line plot of the "<metric>_best_threshold" column of the summary table.
% for metric in metrics:
    ## These metrics are not summarised in this section.
    % if metric in ["roc_auc", "pr_auc", "fpr", "coverage"]:
        <% continue %>
    % endif
    <%
        # Display form of the metric name, e.g. "some_metric" -> "SOME-METRIC".
        metric_label = metric.upper().replace("_", "-")

        # Compare against None: a factor1 column at position 0 is valid but
        # falsy, and would otherwise fall back to sample_id by mistake.
        x_field = "factor1" if name_to_index_summary.get("factor1") is not None else "sample_id"
        fields_barplot = [name_to_index_summary.get(x_field), name_to_index_summary.get(metric+"_best")]
        fields_line = [name_to_index_summary.get(x_field), name_to_index_summary.get(metric+"_best"+"_threshold")]

        # A fresh config per metric; both plots of the metric share it.
        config = {'setMaxX': 1, 'setMinX': 0, 'showLegend': False}
        if name_to_index_summary.get("factor2") is not None:
            # Split the samples by factor2 only when that column exists.
            smp_attr = [name_to_index_summary.get("factor2")]
            config['segregateSamplesBy'] = 'factor2'
        else:
            smp_attr = None

        config["graphOrientation"] = "vertical"
    %>
    <h2 style="text-align: center; background-color:#d6eaf8">Metric: ${metric_label}</h2>
    <div style="display: flex; justify-content: center; align-items: center; width: 90%; background-color:#ecf0f1; margin: 0 auto;">
    ${plotter.barplot(
        id="summary_table",
        fields=fields_barplot,
        header=True,
        row_names=True,
        responsive=False,
        smp_attr=smp_attr,
        height='400px',
        width='400px',
        x_label=metric_label,
        title=f"Best {metric} value",
        config=config
    )}
    ${plotter.line(
        id="summary_table",
        fields=fields_line,
        header=True,
        row_names=True,
        responsive=False,
        smp_attr=smp_attr,
        height='400px',
        width='400px',
        x_label="threshold",
        title=f"Best {metric} threshold",
        config=config
    )}
    </div>
% endfor
</div>
