#!/usr/bin/python
# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius F. Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip

import sys
import os
import traceback
import time
import argparse
import ast
from subprocess import check_call, CalledProcessError, call

import numpy as np
import matplotlib

# Raise the interpreter's recursion limit to 10000 for the whole script.
# NOTE(review): presumably required for deeply nested model graphs when
# loading/compiling models -- confirm before changing the value.
sys.setrecursionlimit(10000)


def parseargs():
    """Parse the command line of the profiling script.

    Returns:
        A 4-tuple ``(model_path, in_sh, gpu, mpl_backend)``:

        * ``model_path`` -- the positional model file path, as given.
        * ``in_sh`` -- the ``--in_sh`` value evaluated as a Python literal,
          or ``None`` if the option was not passed.
        * ``gpu`` -- ``'auto'``, ``None`` (for "none"/"false" or when the
          option is absent) or an integer in ``0..99``.
        * ``mpl_backend`` -- the ``--mpl-backend`` string (default ``'auto'``).
    """
    # NOTE: the names of the two converters below show up in argparse's
    # "invalid <name> value" error messages, so they are part of the
    # user-visible behaviour.
    def convert(s):
        lowered = s.lower()
        if lowered == 'auto':
            return 'auto'
        if lowered in ('none', 'false'):
            return None
        return int(s)

    def convert_sh(sh):
        return ast.literal_eval(sh)

    usage_text = (
        "$ elektronn2-profile </path/to_model_file> "
        "[--in_sh=<input_shape>] "
        "[--gpu={auto|none|<int>}] "
        "[--mpl-backend={auto|<backend_name>|force-<backend_name>}]"
    )

    # Values accepted for --gpu *after* conversion by ``convert``.
    gpu_choices = ['auto', False, None]
    gpu_choices.extend(range(0, 100))

    parser = argparse.ArgumentParser(usage=usage_text)
    parser.add_argument("model_path", type=str)
    parser.add_argument("--in_sh", default=None, type=convert_sh)
    parser.add_argument("--gpu", default=None, type=convert, choices=gpu_choices)
    parser.add_argument("--mpl-backend", type=str, default='auto')

    args = parser.parse_args()
    return args.model_path, args.in_sh, args.gpu, args.mpl_backend

model_path, in_sh, gpu, mpl_backend = parseargs()

# Choose the matplotlib backend now, before other modules import matplotlib:
#   * "agg"             -> always use AGG (no X11 windows),
#   * "force-<backend>" -> use <backend> without probing for an X server,
#   * anything else     -> probe for a running X server first.
if mpl_backend.lower() == 'agg':
    matplotlib.use('AGG')
    print('Using the AGG backend for matplotlib. No support for X11 windows.')
elif mpl_backend.startswith('force-'):
    matplotlib.use(mpl_backend[len('force-'):])
else:
    # Prevent setting of mpl qt-backend on machines without X-server before other modules import mpl.
    with open(os.devnull, 'w') as devnull:
        try:
            # "xset q" only runs successfully while an X server is running.
            check_call(['xset', 'q'], stdout=devnull, stderr=devnull)
            # With "auto" the backend is not set explicitly; matplotlib
            # silently keeps its system default.
            if mpl_backend.lower() != 'auto':
                matplotlib.use(mpl_backend)
            print('Using the {} backend for matplotlib.'.format(matplotlib.get_backend()))
        except (OSError, CalledProcessError):
            # "xset q" is missing or failed: conclude that X is not running.
            print('No X11 server found. Falling back to AGG backend for matplotlib.')
            matplotlib.use('AGG')


# NOTE(review): imported only here, after the matplotlib backend has been
# selected -- presumably because elektronn2 modules import matplotlib
# themselves; confirm before moving this import to the top of the file.
from elektronn2.config import config

config.use_manual_cudnn_conv = True
config.use_manual_cudnn_conv_not_w1 = False
config.use_manual_cudnn_pool = True


if config.device is None:
    # No device configured yet: initialise it from the --gpu argument.
    from elektronn2.utils.gpu import initgpu
    initgpu(gpu)
elif gpu is not None:
    # A device is already configured, so an explicit --gpu request cannot be
    # honoured. Test against None (not truthiness) so that "--gpu=0" is
    # reported as well.
    print("Cannot init gpu to %s because there is already the value %s "
          "in ~/.elektronn2rc." %(gpu, config.device))


from elektronn2.neuromancer.model import modelload


no_gc = False
mfp = False


# Resolve the path of this script *before* changing the working directory
# below; otherwise a relative sys.argv[0] would be resolved against the
# model directory and the child runs further down could not be started.
this_file = os.path.abspath(sys.argv[0])

model_path = os.path.abspath(os.path.expanduser(model_path))
model_dir, model_name = os.path.split(model_path)
os.chdir(model_dir)
# Results are appended to "<model name without extension>-Speed.csv" in the
# model's directory.
f_name = os.path.splitext(model_name)[0] + '-Speed.csv'


def _time_predictions(model, in_shape, n_runs):
    """Time ``n_runs`` calls of ``model.predict`` on random float32 input.

    One extra, untimed prediction is run first.

    Args:
        model: Object providing ``predict(array)``.
        in_shape: Sequence of ints, shape of the random input array.
        n_runs: Number of timed predictions.

    Returns:
        ``(n, speed)`` where ``n`` is the product of the output shape from
        axis 2 onward and ``speed`` is ``n * n_runs / elapsed_seconds / 1e6``
        (logged as "Speed [MPix/s]").
    """
    val = np.random.rand(*in_shape).astype(np.float32)
    y = model.predict(val)  # untimed warm-up run
    t0 = time.time()
    for _ in range(n_runs):
        y = model.predict(val)
    t1 = time.time()
    n = np.prod(y.shape[2:])
    speed = float(n) / (t1 - t0) / 1e6 * n_runs
    return n, speed


def _append_result(path, line):
    """Print ``line`` and append it to the file at ``path``."""
    print(line)
    with open(path, 'a') as f:
        f.write(line)


# Benchmark Model Compiled with static Shape
if in_sh:
    model_static = modelload(model_path, override_mfp_to_active=mfp,
                             imposed_patch_size=in_sh[2:],
                             imposed_batch_size=1, make_weights_constant=True)
    try:
        n, speed = _time_predictions(model_static, in_sh, 3)
        if len(in_sh)==5:
            z = in_sh[2]
            x = in_sh[3]
            s = '%i\t%i\t%i\t%f\t- static\n' % (z, x, n, speed)
        else:
            x = in_sh[2]
            s = '%i\t%i\t%f\t - static\n' % (x, n, speed)

        _append_result(f_name, s)
    except Exception:
        # Best effort: report the failure and still terminate normally.
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        traceback.print_exc(file=sys.stdout)

    sys.exit()

else:
    model_flexi = modelload(model_path, override_mfp_to_active=mfp,
                            make_weights_constant=True)
    # TODO: Also iterate over possible batch sizes?

    fov = model_flexi.prediction_node.shape.fov
    in_sh0 = model_flexi.input_node.shape.spatial_shape
    z_ratio = float(fov[-1])/fov[0]
    ndim = len(fov)
    # Write the column header for this run.
    with open(f_name, 'a') as f:
        if ndim==3:
            f.write("z-shape\txy-shape\tpixel\tSpeed [MPix/s]\n")
        else:
            f.write("xy-shape\tpixel\tSpeed [MPix/s]\n")

    # Gradually enlarge in_sh until the device's memory limit is reached.
    # Collect the tried in_sh values as strings to be passed to this
    # script as the "--in_sh=" argument.
    call_args = []
    try:
        if ndim==2:
            for x in range(in_sh0[-1],3000):
                in_sh = list(model_flexi.input_node.shape)
                in_sh[0] = 1
                in_sh[2] = x
                in_sh[3] = x
                print(in_sh)
                try:
                    n, speed = _time_predictions(model_flexi, in_sh, 10)
                    s = '%i\t%i\t%f\n' % (x, n, speed)
                    _append_result(f_name, s)
                    # Only shapes that worked here are re-run statically.
                    call_args.append(repr(in_sh).replace(' ', ''))
                except (AssertionError, ValueError, RuntimeError):
                    traceback.print_exc(file=sys.stdout)

        elif ndim==3:
            x = in_sh0[-1]+50
            x_incr = 1
            possible_z = np.arange(500)
            z_incr = 1
            while x < 2000:
                # Pick the candidate z closest to x scaled by the fov ratio.
                z_base = x/z_ratio
                z = possible_z[np.abs(z_base - possible_z).argmin()]
                for z_mod in range(-16,16,z_incr):
                    in_sh = list(model_flexi.input_node.shape)
                    in_sh[0] = 1
                    in_sh[2] = z + z_mod
                    in_sh[3] = x
                    in_sh[4] = x
                    print(in_sh)
                    try:
                        n, speed = _time_predictions(model_flexi, in_sh, 3)
                        # z, x, number of output pixels, speed
                        s = '%i\t%i\t%i\t%f\n' % (in_sh[2], x, n, speed)
                        _append_result(f_name, s)
                        # After the first working shape, continue in steps
                        # of 16 and restrict z candidates to z + k*16.
                        x_incr = 16
                        z_incr = 16
                        possible_z = np.arange(z,500,16)

                        call_args.append(repr(in_sh).replace(' ', ''))
                    except (AssertionError, ValueError, RuntimeError):
                        traceback.print_exc(file=sys.stdout)

                x += x_incr

    except MemoryError:
        print("Flexi has reached memory error")

    # Try profiling the collected input shapes individually, in reverse order
    # (the potentially best-performing ones are tried first). Many of these runs
    # are expected to fail because these automatic shapes often lead to invalid models.
    # If a model can be successfully compiled, its performance is evaluated and logged
    # into *-Speed.csv in the model's directory.
    #
    # The child is started from an argument list (no shell), so paths with
    # spaces or shell metacharacters and the "[...]" shape string are passed
    # through verbatim.
    # NOTE(review): --gpu and --mpl-backend are not forwarded to the child
    # runs; they use the script defaults.
    if no_gc:
        child_env = None  # inherit the current environment unchanged
    else:
        child_env = dict(os.environ, THEANO_FLAGS='linker=cvm,allow_gc=True')

    for in_sh_str in reversed(call_args):
        call([sys.executable, this_file, model_path, '--in_sh=' + in_sh_str],
             env=child_env)
