#!/usr/bin/env python

import argparse
import json
import base64
from cStringIO import StringIO
import PIL.Image
import PIL.ImageDraw
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import label2rgb


def labelcolormap(N=256):
    """Build a colormap that assigns a distinct RGB color to each label index.

    The bits of each index are distributed over the three channels: bits
    0, 1 and 2 of the index become the most significant bits of red, green
    and blue respectively, the next three bits become the next lower bit of
    each channel, and so on for eight rounds.

    Args:
        N: number of colormap entries to generate (label indices 0..N-1).

    Returns:
        A ``numpy.ndarray`` of shape ``(N, 3)`` and dtype ``float32`` whose
        rows are RGB triples scaled to the range [0, 1].
    """

    def bitget(byteval, idx):
        # True when bit `idx` of `byteval` is set.
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3))
    # `range` (not `xrange`) keeps this working on both Python 2 and 3.
    for i in range(N):
        label_id = i  # consumed three bits at a time below
        r, g, b = 0, 0, 0
        for j in range(8):
            shift = 7 - j  # fill each channel from its MSB downwards
            r |= bitget(label_id, 0) << shift
            g |= bitget(label_id, 1) << shift
            b |= bitget(label_id, 2) << shift
            label_id >>= 3
        cmap[i] = (r, g, b)
    # Scale 0..255 channel values to 0..1 floats.
    cmap = cmap.astype(np.float32) / 255
    return cmap


def main():
    """Show an annotated image side by side with its colored label overlay.

    Reads the JSON file named on the command line. The JSON is expected to
    carry a base64-encoded image under ``'imageData'`` and a list under
    ``'shapes'`` whose items each have a ``'label'`` name and a ``'points'``
    polygon. Each polygon is rasterised into an integer label image (later
    shapes overwrite earlier ones where they overlap), which is then
    rendered over the image and displayed with a legend of label names.
    """
    # Local import so this block is self-contained; BytesIO replaces the
    # previous write-then-read StringIO buffer.
    import io

    parser = argparse.ArgumentParser()
    parser.add_argument('json_file')
    args = parser.parse_args()

    # Use a context manager so the file handle is closed promptly.
    with open(args.json_file) as fp:
        data = json.load(fp)

    # base64 string -> numpy.ndarray
    # Constructing BytesIO from the bytes leaves the stream at offset 0,
    # so the image reader does not depend on an implicit rewind.
    img_bytes = base64.b64decode(data['imageData'])
    img = np.array(PIL.Image.open(io.BytesIO(img_bytes)))

    # Label names are numbered in order of first appearance; 0 is background.
    target_names = {'background': 0}
    label = np.zeros(img.shape[:2], dtype=np.int32)
    for shape in data['shapes']:
        # get label value
        label_value = target_names.get(shape['label'])
        if label_value is None:
            label_value = len(target_names)
            target_names[shape['label']] = label_value
        # polygon -> mask
        mask = PIL.Image.fromarray(np.zeros(img.shape[:2], dtype=np.uint8))
        # A real list (not a lazy `map` object) works on Python 2 and 3.
        xy = [tuple(point) for point in shape['points']]
        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
        mask = np.array(mask)
        # fill label value
        label[mask == 1] = label_value

    cmap = labelcolormap(len(target_names))
    label_viz = label2rgb(label, img, colors=cmap)
    # Black out background pixels in the overlay.
    label_viz[label == target_names['background']] = 0

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    plt.subplot(121)
    plt.axis('off')
    plt.imshow(img)

    plt.subplot(122)
    plt.axis('off')
    plt.imshow(label_viz)
    # plot target_names legend
    plt_handlers = []
    plt_titles = []
    # Sort by label value so the legend order is deterministic rather than
    # following dict iteration order.
    for l_name, l_value in sorted(target_names.items(),
                                  key=lambda item: item[1]):
        plt_handlers.append(plt.Rectangle((0, 0), 1, 1, fc=cmap[l_value]))
        plt_titles.append(l_name)
    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

    plt.show()


# Run the viewer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
