#!/usr/bin/env python3

import fargv
import json
from PIL import Image
import numpy as np
import frat


# Command-line parameters in fargv format: each key is an option name and each
# value is either a bare default or a ``[default, help text]`` pair.
p = {
    # Annotation files to convert; the default is an empty set.
    "jsons": [
        set(),
        "The json filenames of the annotations; every annotated image is "
        "expected to reside at the same path without the '.json' suffix and "
        "is opened to infer the image size.",
    ],
    # Appended to the image path to build the output filename.
    "output_postfix": [
        ".segm.png",
        "This will be appended to the image filename (the json filename "
        "without '.json') to form the output filename.",
    ],
    # Forwarded unchanged to frat.rectangles_to_rgb / frat.rectangles_to_gray.
    "multiclass": False,
    # True renders with frat.rectangles_to_rgb, False with frat.rectangles_to_gray.
    "rgb_output": True,
    # Comma separated html colors overriding the ones stored in the json, e.g.
    # "#AAFF00,#FFFF23,#3090B0,#FFAA00,#FF2323,#9030B0,#7070F0,#80C023,#10A060,#00DD00,#6023F3,#5000F0"
    "class_colors": "",
    # Comma separated class names; only their count is compared with class_colors, e.g.
    # "Class 1,Class 2,Class 3,Class 4,Class 5,Class 6,Class 7,Class 8,Class 9,Class 10,Class 11,Class 12"
    "class_names": "",
    # When True, class number n is rendered with the html color f"#{n:06X}".
    "numbers_for_colors": False,
}


if __name__ == "__main__":
    # Render every annotation json given on the command line as a segmentation
    # image written next to the annotated image.
    p, _ = fargv.fargv(p)

    # Explicit raises instead of asserts so the checks survive ``python -O``.
    if len(p.class_colors.split(",")) != len(p.class_names.split(",")):
        raise ValueError(
            "class_colors and class_names must contain the same number of "
            "comma separated entries"
        )

    for json_name in p.jsons:
        if not json_name.lower().endswith(".json"):
            raise ValueError(f"Annotation filename must end with '.json': {json_name!r}")
        # The annotated image lives at the json path minus the '.json' suffix.
        image_path = json_name[:-5]

        with open(json_name, "r") as json_file:
            gt = json.load(json_file)
        # Only the image dimensions are needed; close the file right away.
        with Image.open(image_path) as source_image:
            image_size = source_image.size

        rectangles, classids = gt["rect_LTRB"], gt["rect_classes"]

        # Pick the html colors: class numbers, command line override, or json.
        if p.numbers_for_colors:
            html_colors = [f"#{n:06X}" for n in range(len(gt["class_colors"]))]
        elif p.class_colors != "":
            # Hand over a list of colors like the other two branches do,
            # rather than the raw comma separated string.
            # NOTE(review): assumes frat.htmlrgb_to_uint8 expects a list of
            # html color strings -- confirm against frat.
            html_colors = [color.strip() for color in p.class_colors.split(",")]
        else:
            html_colors = gt["class_colors"]
        class_colors = frat.htmlrgb_to_uint8(html_colors)

        if not gt["class_names"]:
            raise ValueError(f"No class names found in {json_name!r}")

        if p.rgb_output:
            segmentation = frat.rectangles_to_rgb(rectangles, classids, image_size, class_colors, p.multiclass)
        else:
            segmentation = frat.rectangles_to_gray(rectangles, classids, image_size, class_colors, p.multiclass)
        print(segmentation.shape, segmentation.dtype)
        Image.fromarray(segmentation).save(image_path + p.output_postfix)
