#!python

# Copyright (c) 2017 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================

import os
import random

import numpy as np
import wave
from PIL import Image
from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log

# All records produced by this script are written under this directory.
logdir = './scratch_log'

logw = LogWriter(logdir, sync_cycle=2000)

# The same tag is registered once per mode, giving one scalar writer for
# 'train' and one for 'test'.
with logw.mode('train') as logger:
    scalar0 = logger.scalar("scratch/scalar")
with logw.mode('test') as logger:
    scalar1 = logger.scalar("scratch/scalar")

# Feed both writers a random walk over steps 1..99. The offsets (0.3 vs 0.7)
# bias the expected increment upwards for scalar0 and downwards for scalar1.
last_record0 = last_record1 = 0.
for step in range(1, 100):
    drift0 = random.random() - 0.3
    drift1 = random.random() - 0.7
    last_record0 += 0.1 * drift0
    last_record1 += 0.1 * drift1
    scalar0.add_record(step, last_record0)
    scalar1.add_record(step, last_record1)

# create histogram
with logw.mode('train') as logger:
    histogram = logger.histogram("scratch/histogram", num_buckets=200)
    histogram0 = logger.histogram("scratch/histogram0", num_buckets=200)

    # Steps 1..99: 1000 normal samples per step; the mean grows slowly with
    # the step while the standard deviation shrinks.
    for step in range(1, 100):
        mean = 0.1 + step * 0.001
        spread = 200. / (100 + step)
        histogram.add_record(step, np.random.normal(mean, spread, size=1000))

    # Steps 1..49: same idea with a faster-moving mean and a slightly
    # narrower spread.
    for step in range(1, 50):
        mean = 0.1 + step * 0.003
        spread = 200. / (120 + step)
        histogram0.add_record(step, np.random.normal(mean, spread, size=1000))
# create image
with logw.mode("train") as logger:
    image = logger.image("scratch/dog", 4)  # randomly sample 4 images one pass
    image0 = logger.image("scratch/random", 4)

    # Load the bundled picture at half its original size. The source file is
    # closed as soon as the resized copy exists, and resize() receives a
    # plain (width, height) tuple of ints.
    with Image.open(os.path.join(ROOT, 'python/dog.jpg')) as full_size:
        dog_jpg = full_size.resize(
            (full_size.size[0] // 2, full_size.size[1] // 2))
    # PIL reports size as (width, height); the shape here is stored as
    # [height, width, channels].
    dog_shape = [dog_jpg.size[1], dog_jpg.size[0], 3]

    # Every crop has the same shape: height, width, channels (3 for RGB).
    target_shape = [100, 100, 3]

    # add dog's image
    for pass_ in range(4):
        image.start_sampling()
        for sample in range(10):
            # randomly crop a dog's image.
            left_x = random.randint(0, dog_shape[1] - target_shape[1])
            left_y = random.randint(0, dog_shape[0] - target_shape[0])
            right_x = left_x + target_shape[1]
            right_y = left_y + target_shape[0]

            # A more efficient way to sample images: first ask whether this
            # sample will be kept, and only build the pixel data when the
            # returned index is non-negative.
            idx = image.is_sample_taken()
            if idx >= 0:
                data = np.array(
                    dog_jpg.crop((left_x, left_y, right_x, right_y))).flatten()
                # add this image to log
                image.set_sample(idx, target_shape, data)
            # You can also just write the following code; it is clearer, but
            # the image is processed even if it will not be sampled.
            # data = np.array(
            #     dog_jpg.crop((left_x, left_y, right_x,
            #                     right_y))).flatten()
            # image.add_sample(target_shape, data)

        image.finish_sampling()

    # add randomly generated image
    random_shape = [40, 30, 3]
    for pass_ in range(4):
        image0.start_sampling()
        for sample in range(10):
            data = np.random.random(random_shape).flatten()
            image0.add_sample(random_shape, list(data))
        image0.finish_sampling()



# create audio
with logw.mode("train") as logger:
    audio = logger.audio("scratch/audio_1", 4)  # randomly sample 4 audio one pass

    # Read the whole wav file as raw bytes, CHUNK frames at a time.
    CHUNK = 4096
    wavdata = []
    f = wave.open(os.path.join(ROOT, 'python/testing.wav'), "rb")
    try:
        chunk = f.readframes(CHUNK)
        while chunk:
            # np.frombuffer replaces np.fromstring, whose binary mode is
            # deprecated; the resulting uint8 values are the same.
            wavdata.extend(np.frombuffer(chunk, dtype='uint8'))
            chunk = f.readframes(CHUNK)
    finally:
        # try/finally rather than `with` keeps the script usable on Python 2,
        # which download_onnx() still accounts for.
        f.close()

    # 8k sample rate, 16bit frame, 1 channel
    # NOTE(review): these values are hard-coded, not read from the wav
    # header -- confirm they match python/testing.wav.
    shape = [8000, 2, 1]

    for pass_ in range(4):
        audio.start_sampling()
        for sample in range(10):
            # Only hand over the data when this sample will be kept.
            idx = audio.is_sample_taken()
            if idx >= 0:
                audio.set_sample(idx, shape, wavdata)

        audio.finish_sampling()

# Create text
with logw.mode("train") as logger:
    text = logger.text("scratch/generated_text1")

    ascii_a = ord('a')
    ascii_z = ord('z')
    # Generate 100 random strings of 8 lowercase letters each.
    for i in range(100):
        # Built with join under a name that does not shadow the builtin
        # `str`; characters are drawn in the same order as before.
        random_text = ''.join(
            chr(random.randint(ascii_a, ascii_z)) for _ in range(8))

        # Add a new text record to the log writer
        text.add_record(i, random_text)


# Create embeddings
with logw.mode("train") as logger:
    embedding = logger.embedding()

    # Eleven 3-dimensional demo vectors.
    hot_vectors = [
        [10.0, 8.04, 2], [8.0, 6.95, 2], [13.0, 7.58, 2],
        [9.0, 8.81, 3], [11.0, 8.33, 4], [14.0, 9.96, 5],
        [6.0, 7.24, 6], [4.0, 4.26, 7], [12.0, 10.84, 8],
        [7.0, 4.8, 1], [5.0, 5.68, 2],
    ]

    labels = [
        "yellow", "blue", "red", "king", "queen", "man",
        "women", "kid", "adult", "light", "dark",
    ]
    # Pair each label with its integer id, in label order.
    # NOTE(review): the ids presumably select rows of hot_vectors -- confirm
    # against add_embeddings_with_word_dict.
    word_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0]
    word_dict = dict(zip(labels, word_ids))

    embedding.add_embeddings_with_word_dict(hot_vectors, word_dict)


def download_onnx():
    '''
    Download a prebuilt MNIST ONNX model into the log directory.

    This is a scratch demo: it does not generate an ONNX proto itself, it
    only downloads a prebuilt ONNX file and stores it as
    ``mnist_model.onnx`` under ``logdir``.

    For real cases, just refer to README.
    '''

    import ssl
    import sys
    from contextlib import closing

    if sys.version_info[0] == 3:
        import urllib.request as ur

    else:
        # Not Python 3 - today, it is most likely to be Python 2
        import urllib as ur

    # NOTE(review): hostname and certificate checks are switched off here, so
    # the download is not authenticated. Kept as-is because it looks like a
    # deliberate best-effort for hosts without a usable CA bundle -- confirm
    # whether verification can be re-enabled.
    myssl = ssl.create_default_context()
    myssl.check_hostname = False
    myssl.verify_mode = ssl.CERT_NONE
    onnx_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mnist_model.onnx?raw=true"
    log.warning('download ONNX file from {}'.format(onnx_url))
    # closing() releases the connection on both Python 2 and 3, even when
    # read() raises.
    with closing(ur.urlopen(onnx_url, context=myssl)) as response:
        onnx_model = response.read()
    # Make sure the target directory exists before writing into it.
    if not os.path.isdir(logdir):
        os.makedirs(logdir)
    with open(os.path.join(logdir, 'mnist_model.onnx'), 'wb') as f:
        f.write(onnx_model)
    log.warning('ONNX model ready! use visualdl --logdir=scratch_log -m scratch_log/mnist_model.onnx to launch')


# Fetch the prebuilt ONNX model and save it under the log directory.
download_onnx()
