#!python
import os
import random

import numpy as np
from PIL import Image
from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log

logdir = './scratch_log'

logw = LogWriter(logdir, sync_cycle=2000)

# Register one scalar per mode; both modes share the tag "scratch/scalar".
with logw.mode('train') as logger:
    scalar0 = logger.scalar("scratch/scalar")

with logw.mode('test') as logger:
    scalar1 = logger.scalar("scratch/scalar")

# Feed each scalar a random walk. The train increments are centred on
# +0.02 and the test increments on -0.02, so the two curves drift apart.
train_walk = 0.
test_walk = 0.
for step in range(1, 100):
    # Draw the train increment first, then the test one, so the sequence
    # of random numbers consumed stays the same.
    train_delta = 0.1 * (random.random() - 0.3)
    test_delta = 0.1 * (random.random() - 0.7)
    train_walk += train_delta
    test_walk += test_delta
    scalar0.add_record(step, train_walk)
    scalar1.add_record(step, test_walk)

# Create two histograms in train mode and fill them with normally
# distributed samples whose mean grows and whose spread shrinks with step.
with logw.mode('train') as logger:
    histogram = logger.histogram("scratch/histogram", num_buckets=200)
    histogram0 = logger.histogram("scratch/histogram0", num_buckets=200)

    # 99 records for the first histogram.
    for step in range(1, 100):
        mean = 0.1 + step * 0.001
        stddev = 200. / (100 + step)
        samples = np.random.normal(mean, stddev, size=1000)
        histogram.add_record(step, samples)

    # 49 records for the second histogram, with a faster-moving mean.
    for step in range(1, 50):
        mean = 0.1 + step * 0.003
        stddev = 200. / (120 + step)
        samples = np.random.normal(mean, stddev, size=1000)
        histogram0.add_record(step, samples)
# create image
with logw.mode("train") as logger:
    # Each component randomly samples 4 images per pass.
    image = logger.image("scratch/dog", 4)
    image0 = logger.image("scratch/random", 4)

    # Load the bundled dog picture and halve both of its dimensions.
    dog_jpg = Image.open(os.path.join(ROOT, 'python/dog.jpg'))
    half_size = np.floor_divide(np.array(dog_jpg.size), 2)
    dog_jpg = dog_jpg.resize(half_size)
    img_width, img_height = dog_jpg.size

    # Shape of every crop: height, width, channels (3 for RGB).
    crop_shape = [100, 100, 3]

    # add dog's image
    for _pass in range(4):
        image.start_sampling()
        for _sample in range(10):
            # Pick the top-left corner of a random crop that fits the image.
            left_x = random.randint(0, img_width - crop_shape[1])
            left_y = random.randint(0, img_height - crop_shape[0])

            # A more efficient way to sample images: first ask whether this
            # image will be taken by reservoir sampling, and only do the
            # cropping work when it is.
            slot = image.is_sample_taken()
            if slot < 0:
                continue

            box = (left_x, left_y,
                   left_x + crop_shape[1], left_y + crop_shape[0])
            pixels = np.array(dog_jpg.crop(box)).flatten()
            # add this image to log
            image.set_sample(slot, crop_shape, pixels)
            # You can also just write the following code; it is clearer, but
            # it processes the image even if it will not be sampled.
            # pixels = np.array(dog_jpg.crop(box)).flatten()
            # image.add_sample(crop_shape, pixels)

        image.finish_sampling()

    # add randomly generated image
    noise_shape = [40, 30, 3]
    for _pass in range(4):
        image0.start_sampling()
        for _sample in range(10):
            noise = np.random.random(noise_shape).flatten()
            image0.add_sample(noise_shape, list(noise))
        image0.finish_sampling()


def download_graph_image():
    '''
    This is a scratch demo: it does not generate an ONNX proto, but just
    downloads an image that was generated before, to show how the graph
    frontend works.

    The image is fetched from a fixed GitHub URL and written to
    ``graph.jpg`` inside ``logdir``. Nothing is returned; any error raised
    by the download or by the file write propagates to the caller.

    For real cases, just refer to README.
    '''

    import ssl
    import sys
    from contextlib import closing

    if sys.version_info[0] == 3:
        import urllib.request as ur

    else:
        # Not Python 3 - today, it is most likely to be Python 2
        import urllib as ur

    # NOTE(review): certificate and hostname verification are switched off
    # here, so the download is not protected against tampering. Kept as-is
    # to preserve existing behavior; confirm whether it is still needed.
    myssl = ssl.create_default_context()
    myssl.check_hostname = False
    myssl.verify_mode = ssl.CERT_NONE
    image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
    log.warning('download graph demo from {}'.format(image_url))
    # closing() guarantees the HTTP response is released even if read()
    # fails, and works whether or not the response object is itself a
    # context manager.
    with closing(ur.urlopen(image_url, context=myssl)) as response:
        graph_image = response.read()
    with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
        f.write(graph_image)
    log.warning('graph ready!')


# Fetch the demo graph image into the log directory when the script runs.
download_graph_image()
