#!python

"""
Script mols_draw
==========================
This script is used for generating drawings of molecules.
"""

# standard
import warnings
import sys
from datetime import datetime
import logging
import pandas as pd
import argparse
import re
from PIL import Image
from PIL import ImageOps
from PIL import ImageDraw
import numpy as np
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
# dev
import npfc
from npfc import draw
from npfc import load
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _parse_bool(value):
    """Convert a command-line string into a boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--round False`` would silently enable rounding. This converter
    understands the usual textual spellings instead.

    :param value: the raw string received from the command line
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value is not a recognized boolean spelling
    """
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string is kept falsy, as it was with type=bool
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"EXPECTED A BOOLEAN VALUE (True/False) BUT GOT '{value}' INSTEAD!")


def _parse_size(text):
    """Parse a size argument written as '(width,height)' into a tuple of two ints.

    Whitespace is ignored. The whole string has to match the expected syntax,
    so trailing garbage is rejected here rather than failing later in ``int()``.

    :param text: the raw size string, e.g. '(400,400)'
    :return: a tuple of two ints
    :raises ValueError: if the string does not follow the '(xxx,xxx)' syntax
    """
    size = text.replace(' ', '')
    if not re.fullmatch(r'\([0-9]+,[0-9]+\)', size):
        raise ValueError(f"ERROR! EXPECTED SYNTAX '(xxx,xxx)' FOR SIZE ARGUMENT BUT GOT '{size}' INSTEAD!")
    return tuple(int(x) for x in size.strip('()').split(','))


def _round_image(img, border):
    """Return a copy of a PIL image padded with white and masked by an inscribed ellipse.

    Procedure adapted from:
    https://stackoverflow.com/questions/51486297/cropping-an-image-in-a-circular-way-using-python
    https://stackoverflow.com/questions/11142851/adding-borders-to-an-image-using-python

    :param img: the PIL image to round
    :param border: width in pixels of the white border added on every side before masking
    :return: a new PIL image with an alpha channel that is opaque only inside the ellipse
    """
    img = img.convert("RGB")
    # enlarge the image so atoms do not get cropped while applying the round mask
    img = ImageOps.expand(img, border=border, fill='white')
    # PIL reports the size as (width, height)
    width, height = img.size
    # same-size alpha layer: transparent everywhere but inside the ellipse
    alpha = Image.new('L', img.size, 0)
    ImageDraw.Draw(alpha).pieslice([0, 0, width, height], 0, 360, fill=255)
    # stack the alpha layer on top of the RGB channels
    rgba = np.dstack((np.array(img), np.array(alpha)))
    return Image.fromarray(rgba)


def main():
    """Draw every molecule of the input file into its own PNG or SVG file.

    Command-line arguments are read from ``sys.argv``. The input file is loaded
    with ``npfc.load.file``, reduced to the first row of each 'idm' value, and
    every molecule in column 'mol' is drawn into '<output_dir>/<idm>.png' or
    '<output_dir>/<idm>.svg'. PNG images are optionally masked so they look
    circular.

    :raises ValueError: if the size or the format argument is invalid
    :raises SystemExit: on argument parsing errors, or with code 1 if column 'mol' or 'idm' is missing
    """
    # init
    description = """Script used drawing individual SVG files for each molecule.

    It uses the installed npfc libary in your favorite env manager.

    Example:
        >>> # Round images
        >>> mols_draw input_file.csv.gz output_dir
        >>> # Squared images
        >>> mols_draw input_file.csv.gz output_dir --round False

    """
    ds = [('START', datetime.now())]
    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_file', type=str, default=None, help="Input file. Molecules to display are in column 'mol'. If a _colormap column is also avaialble, molecules are highlighted.")
    parser.add_argument('output_dir', type=str, default=None, help="Output directory where to save SVG files.")
    parser.add_argument('-s', '--size', type=str, default='(400,400)', help="Size allocated for each subimage.")
    parser.add_argument('-f', '--format', type=str, default='PNG', help="Either PNG or SVG. PNG only is supported with Cytoscape.")
    # type=bool would turn any non-empty string (including 'False') into True
    parser.add_argument('-r', '--round', type=_parse_bool, default=True, help="For PNG only, round the image so it becomes circular.")
    parser.add_argument('--atomLabelFontSize', type=float, default=14, help="Size of the atom labels.")
    parser.add_argument('--bondLineWidth', type=float, default=2.2, help="Thickness of bonds.")
    parser.add_argument('--rescale', type=float, default=1.0, help="Rescale the drawing. Smaller if between 0 and 1, Larger if greater than 1.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING MOLECULE DRAWING")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40

    # parse arguments
    utils.check_arg_input_file(args.input_file)
    utils.check_arg_output_dir(args.output_dir)
    # NOTE: atomLabelFontSize and bondLineWidth are validated here but are not
    # passed to the drawing call below
    utils.check_arg_positive_number(args.atomLabelFontSize)
    utils.check_arg_positive_number(args.bondLineWidth)

    # size
    size = _parse_size(args.size)
    input_format, input_compression = utils.get_file_format(args.input_file)

    # format
    img_format = args.format.upper()
    if img_format not in ('PNG', 'SVG'):
        raise ValueError(f"ERROR! EXPECTED FORMAT IS EITHER SVG OR PNG BUT GOT '{img_format}' INSTEAD!")
    use_svg = img_format == 'SVG'

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_FILE".ljust(pad) + f"{args.input_file}")
    logger.info("INPUT_FORMAT".ljust(pad) + f"{input_format}")
    logger.info("INPUT_COMPRESSION".ljust(pad) + f"{input_compression}")
    logger.info("OUTPUT_DIR".ljust(pad) + f"{args.output_dir}")
    logger.info("SIZE".ljust(pad) + f"{size}")
    logger.info("FORMAT".ljust(pad) + f"{img_format}")
    if not use_svg:
        logger.info("ROUND".ljust(pad) + f"{args.round}")
    logger.info("RESCALE".ljust(pad) + f"{args.rescale}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    logger.info("SETTING DRAWING OPTIONS")
    ds.append(('CONFIGURING JOB', datetime.now()))

    # load input_file
    logger.info("LOADING INPUT FILE")
    df = load.file(args.input_file)
    # both columns are required: 'mol' is drawn, 'idm' groups rows and names the output files
    for required_column in ('mol', 'idm'):
        if required_column not in df.columns:
            logger.error(f"ERROR! COLUMN '{required_column}' NOT FOUND IN INPUT FILE, ABORTING!")
            sys.exit(1)
    nentries = len(df.index)
    # keep a single row per molecule id
    df = df.groupby('idm').first().reset_index()
    nmols = len(df.index)
    if args.rescale != 1.0:
        logger.info(f"RESCALING MOLECULES WITH f={args.rescale}")
        # the return value is deliberately ignored: per the original author's
        # note, draw.rescale modifies the molecule in place
        for mol in df['mol']:
            draw.rescale(mol, f=args.rescale)
    logger.info(f"FOUND {nmols:,d} MOLS IN {nentries:,d} ENTRIES")
    ds.append(('LOADING INPUT FILE', datetime.now()))

    # compute images
    logger.info("GENERATING DRAWINGS")
    for i, (idm, mol) in enumerate(zip(df['idm'], df['mol'])):
        try:
            img = Draw.MolsToGridImage([mol],
                                       molsPerRow=1,
                                       subImgSize=size,
                                       useSVG=use_svg,
                                       )
            if use_svg:
                # with useSVG=True the drawing is written out as text
                output_svg = f"{args.output_dir}/{idm}.svg"
                with open(output_svg, "w") as FILE:
                    FILE.write(img)
            else:
                output_png = f"{args.output_dir}/{idm}.png"
                if args.round:
                    img = _round_image(img, border=size[0] // 4)
                # export (possibly rounded) image
                img.save(output_png)
        except UnicodeEncodeError:
            logger.error(f"ERROR! NO IMAGE GENERATED FOR MOL#{i} (idm={idm}), SMILES={Chem.MolToSmiles(mol)}")
    ds.append(('GENERATING DRAWINGS', datetime.now()))

    # end
    ds.append(("END", datetime.now()))
    logger.info("SUMMARY")
    # each step lasted from the previous timestamp to its own
    for (_, t_previous), (label, t_current) in zip(ds, ds[1:]):
        if label != 'END':
            logger.info(f"COMPUTATIONAL TIME: {label}".ljust(pad * 2) + f"{t_current - t_previous}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{ds[-1][1] - ds[0][1]}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Only run the drawing job when this file is executed as a script (not on
# import), then terminate the interpreter with a success status.
if __name__ == "__main__":
    main()
    sys.exit(0)
