#!/usr/bin/env python3
import sys
import os
import argparse
import multiprocessing
import pathlib
from concurrent.futures import ThreadPoolExecutor
from timeit import default_timer as timer
import numpy as np
import czifile
import tifffile
import png

import imageio

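# TODO: finish the migration to imageio. The writer classes below already use
# imageio (the original patterns are kept here for reference); the `tifffile`
# and `png` imports above appear unused in this file and could be dropped.
# PNG:  imageio.imwrite(filename, data, compression=self.compress, optimize=True)
# TIF:  with imageio.get_writer(filename) as w:
#           w.append_data(data, {'compress': self.compress})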

# ----------------------------------------------------------------------
# Writer classes

class ImageWriter:
    '''Write each frame to its own image file through imageio.

    Files are named ``{outprefix}{frame:05}.{extension}``.
    '''

    def __init__(self, outprefix, extension, compress):
        '''
        Args:
            outprefix: prefix (str or path-like) placed before the frame number.
            extension: image format, either 'png' or 'tif'.
            compress: compression level forwarded to the writer.

        Raises:
            ValueError: if ``extension`` is not 'png' or 'tif'. Without this
                check the writer options would be left undefined and the
                failure would only surface later, in ``write_frame``.
        '''
        self.outprefix = outprefix
        self.extension = extension

        if extension == 'png':
            # PNG options are passed when the writer is created
            self.keywords = dict(compression=compress, optimize=True)
            self.metadata = dict()
        elif extension == 'tif':
            # TIFF compression is passed per frame via append_data's metadata
            self.keywords = dict()
            self.metadata = dict(compress=compress)
        else:
            raise ValueError(f'unsupported image format: {extension!r}')

    def write_frame(self, frame, data):
        '''Write ``data`` to the file numbered ``frame``.'''
        filename = f'{self.outprefix}{frame:05}.{self.extension}'
        with imageio.get_writer(filename, **self.keywords) as w:
            w.append_data(data, self.metadata)

class VideoWriter:
    '''Append frames to a single 30 fps ``{outprefix}.mp4`` video via imageio.

    uint16 frames keep only their high byte. Frames whose shape differs from
    ``size`` (rounded up to even dimensions), or whose dtype is not uint8, are
    copied into a zero-padded uint8 buffer, which crops anything larger.
    '''

    def __init__(self, outprefix, size):
        # Format rather than concatenate: outprefix may be a pathlib.Path,
        # which does not support `+ str`.
        self.writer = imageio.get_writer(f'{outprefix}.mp4', fps=30)

        # Ensure size is a multiple of 2
        size = (size[0] + size[0] % 2,
                size[1] + size[1] % 2)
        self.size = size
        # Padding/conversion buffer, allocated lazily on first use
        self.buffer = None

    def start(self):
        '''No-op; the underlying writer is opened in __init__.'''
        pass

    def write_frame(self, frame, data):
        '''Append one frame; ``frame`` is accepted for interface parity only.'''
        # Lossy conversion from 16 bit to 8 bit (keep the high byte)
        if data.dtype == np.uint16:
            data = np.right_shift(data, 8)

        # Pad, crop, or convert data if needed
        if data.shape != self.size or data.dtype != np.uint8:
            if self.buffer is None:
                self.buffer = np.zeros(self.size, dtype=np.uint8)
            rows = min(self.size[0], data.shape[0])
            cols = min(self.size[1], data.shape[1])
            if rows < self.size[0] or cols < self.size[1]:
                # Clear the padding so an earlier, larger frame does not
                # leave stale pixels around a smaller one.
                self.buffer.fill(0)
            self.buffer[0:rows, 0:cols] = data[0:rows, 0:cols]
            data = self.buffer

        # Write
        self.writer.append_data(data)

    def stop(self):
        '''Close the underlying video writer.'''
        self.writer.close()


# ----------------------------------------------------------------------
# Utilities

def ensure_dir(file_path):
    '''Ensure the parent directory of ``file_path`` exists.

    Only the directory part of the path is created; a bare file name with no
    directory component is a no-op. Accepts str or path-like objects.
    '''
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok avoids the race between an exists() check and makedirs()
        os.makedirs(directory, exist_ok=True)

def to_8bit(image, display_min, display_max):
    '''Linearly map values in [display_min, display_max] of ``image`` to 0..255.

    Values outside the range are clipped first. NOTE: ``image`` is modified in
    place (clipped, shifted and divided) as scratch space; the returned uint8
    array is a new array.

    Raises:
        ValueError: if ``display_max < display_min``.
    '''
    # Use plain Python ints: numpy integer scalars (e.g. uint16 values from
    # np.amin/np.amax) overflow in ``display_max - display_min + 1`` for the
    # full 0..65535 range, which would make the divisor zero.
    display_min, display_max = int(display_min), int(display_max)
    if display_max < display_min:
        raise ValueError(f'display_max ({display_max}) must be >= '
                         f'display_min ({display_min})')
    image.clip(display_min, display_max, out=image)
    image -= display_min
    # Width of the input range covered by each of the 256 output levels
    bin_width = (display_max - display_min + 1) / 256
    np.floor_divide(image, bin_width, out=image, casting='unsafe')
    return image.astype(np.uint8)

def make_8bit_lut(display_min, display_max):
    '''Build a 65536-entry lookup table from 16-bit values to uint8.

    Entry ``i`` is what ``to_8bit`` produces for input value ``i`` with the
    given display range. Apply it with ``np.take(lut, image)``.
    '''
    every_level = np.arange(65536, dtype=np.uint16)
    table = to_8bit(every_level, display_min, display_max)
    return table

def minmax_subblock(directory_entry, subblock, tile, verbose=True):
    '''Return ``(min, max)`` of one decoded tile.

    Has the callback signature used by ``traverse_subblock``; ``subblock`` is
    unused. When ``verbose`` is set, prints the frame index
    (``directory_entry.start[1]``) for frames 0-9 and every tenth frame.
    '''
    frame_index = directory_entry.start[1]
    if verbose:
        if frame_index < 10 or frame_index % 10 == 0:
            print(frame_index, end=' ', flush=True)
    smallest = np.amin(tile)
    largest = np.amax(tile)
    return smallest, largest

def convert_subblock(directory_entry, subblock, tile, outprefix,
                     image_writer=None, video_writer=None,
                     lut=None, verbose=True):
    """Post-process one decoded tile and hand it to the active writers.

    The frame index is ``directory_entry.start[1]``. The tile has its length-1
    axes squeezed out and is optionally remapped through ``lut``. It is then
    passed to ``image_writer`` and then ``video_writer``, skipping whichever is
    None. ``subblock`` and ``outprefix`` are accepted but unused.
    """
    frame = directory_entry.start[1]
    if verbose:
        if frame < 10 or frame % 10 == 0:
            print(frame, end=' ', flush=True)

    pixels = np.squeeze(tile)

    # An optional lookup table remaps every pixel value (e.g. to 8 bit)
    if lut is not None:
        pixels = np.take(lut, pixels)

    for writer in (image_writer, video_writer):
        if writer is not None:
            writer.write_frame(frame, pixels)

def traverse_subblock(directory_entry, func, resize, order, **kwargs):
    '''Decode one subblock and return ``func``'s result for it.

    ``func`` is called as ``func(directory_entry, subblock, tile, **kwargs)``,
    where ``subblock`` is ``directory_entry.data_segment()`` and ``tile`` is
    ``subblock.data(resize=resize, order=order)``.
    '''
    segment = directory_entry.data_segment()
    return func(directory_entry, segment,
                segment.data(resize=resize, order=order), **kwargs)

def traverse_subblocks(directory_entries, func, workers=1, resize=True, order=0, **kwargs):
    '''Run ``traverse_subblock`` over every entry on a thread pool.

    ``workers`` is the pool size. Results come back as a list in the same
    order as ``directory_entries``, because ``Executor.map`` preserves input
    order.
    '''
    def visit(entry):
        return traverse_subblock(entry, func, resize, order, **kwargs)

    with ThreadPoolExecutor(max_workers=workers) as pool:
        ordered_results = pool.map(visit, directory_entries)
        return list(ordered_results)

# Specify and parse arguments
parser = argparse.ArgumentParser(description='Convert czi file to tiff images.')
parser.add_argument('infile',
                    help='the input czi file to convert')

parser.add_argument('--outprefix', metavar="PREFIX",
                    help='prefix for the output files (default is same as input)')

parser.add_argument('--to-images', action='store_true',
                    help='convert to images')
parser.add_argument('--image-format', choices=['png', 'tif'], default='tif',
                    help='format for resulting images (default: %(default)s)')
# choices=range(10) rejects levels outside the documented 0-9 range at parse
# time instead of passing them through to the image writer.
parser.add_argument('--image-compression', type=int, default=9, metavar='LEVEL',
                    choices=range(10),
                    help='compression to apply to images 0-9 (default: %(default)s)')


parser.add_argument('--to-video', action='store_true',
                    help='convert to video')

parser.add_argument('--start-frame', type=int, default=0, metavar='N',
                    help='first frame to convert')
parser.add_argument('--end-frame', type=int, default=sys.maxsize, metavar='N',
                    help='last frame to convert')

# nargs='?': a bare --to-8bit takes const ('full'), but the option will also
# consume a following non-option argument as its value, so pass the input file
# first or use the --to-8bit=VALUE form.
parser.add_argument('--to-8bit', nargs='?', const='full', metavar='{full, auto, M,N}',
                    help='convert to 8 bit, "full" means use entire range of values; '
                    '"auto" means compute min and max values and use those; '
                    'two numbers separated by a comma, "M,N", specify the range to use. '
                    '(default: %(const)s)')

parser.add_argument('--workers', type=int, metavar='N', default=multiprocessing.cpu_count(),
                    help='number of frames to run in parallel (default: %(default)s)')
parser.add_argument('--quiet', action='store_true',
                    help='turn off output')
args = parser.parse_args()

# Munge arguments

verbose = not args.quiet
info = print if verbose else lambda *a, **b: None

if args.workers < 1:
    parser.error('--workers must be at least 1')

# Validate an explicit "M,N" display range before doing any work, and report
# malformed values through argparse rather than with a traceback.
display_range = None
if args.to_8bit is not None and args.to_8bit not in ('full', 'auto'):
    try:
        display_range = tuple(int(val) for val in args.to_8bit.split(','))
    except ValueError:
        display_range = ()
    if len(display_range) != 2 or not 0 <= display_range[0] <= display_range[1] <= 2**16 - 1:
        parser.error('--to-8bit expects "full", "auto" or "M,N" with '
                     f'0 <= M <= N <= 65535, got {args.to_8bit!r}')

if args.outprefix is None:
    path = pathlib.Path(args.infile)
    outdir = path.parents[0].joinpath(path.stem)
    args.outprefix = outdir.joinpath(path.stem)
    info("Outputing to", outdir, args.outprefix.parents[0])

if args.to_video:
    if not args.to_8bit:
        print("Warning: Lossy conversion from 16-bit to 8-bit for video.")

# --to-images is a store_true flag (always True/False). Write images when it
# is given, or by default when no video was requested, so a plain invocation
# still converts to images.
write_images = args.to_images or not args.to_video

# Make sure the output directory exists
ensure_dir(args.outprefix)

# Store current time to time the whole thing
start_time = timer()

with czifile.CziFile(args.infile) as czi:
    # Throw out all but entries in the given frame range
    directory_entries = [de for de in czi.filtered_subblock_directory
                         if (args.start_frame <= de.start[1] <= args.end_frame)]
    if not directory_entries:
        sys.exit(f'No frames in {args.infile} within the requested frame range')

    # Remember whether locking was enabled: args.workers may be lowered to 1
    # for video below, and the lock must still be released afterwards.
    locked = args.workers > 1
    if locked:
        czi._fh.lock = True

    image_writer = None
    if write_images:
        image_writer = ImageWriter(args.outprefix, args.image_format, args.image_compression)

    video_writer = None
    try:
        if args.to_video:
            # For now, ignore offsets
            size = directory_entries[0].shape[3:5]
            # str(): the default outprefix is a pathlib.Path
            video_writer = VideoWriter(str(args.outprefix), size)
            video_writer.start()

        if args.to_8bit is not None:
            if args.to_8bit == 'auto':
                # Find min max
                info("Finding min and max values... ", end='')
                res = np.array(traverse_subblocks(directory_entries, minmax_subblock,
                                                  workers=args.workers, verbose=verbose))
                # Plain ints: numpy integer scalars (e.g. uint16) can overflow
                # when the 8-bit conversion computes the range width.
                min_value, max_value = int(np.amin(res[:, 0])), int(np.amax(res[:, 1]))
                info(f'\nMin={min_value}, Max={max_value}')
            elif args.to_8bit == 'full':
                min_value, max_value = 0, 2**16-1
            else:
                min_value, max_value = display_range
            lut = make_8bit_lut(min_value, max_value)
        else:
            lut = None

        # Convert
        if args.to_video and args.workers > 1:
            print("Warning: Can't use multiple workers for video, using 1")
            args.workers = 1
        info("Converting... ", end='')
        traverse_subblocks(directory_entries, convert_subblock, workers=args.workers,
                           outprefix=args.outprefix, image_writer=image_writer,
                           video_writer=video_writer, lut=lut, verbose=verbose)
        info("")
    finally:
        if locked:
            czi._fh.lock = None
        # Close the video writer even if conversion failed part-way
        if video_writer is not None:
            info("Waiting for video...")
            video_writer.stop()

info(f'Elapsed time {timer() - start_time:.2f}s')
