#!/usr/bin/env python

from __future__ import print_function

import binascii
import math
import struct
import sys
import zlib

# Byte signature expected at the start of every PNG file: the writer emits it
# first, and both the reader and the command-line entry point use it to
# recognise PNG input.
MAGIC = b'\x89PNG\r\n\x1a\n'


class PNGWriter(object):
    """Store arbitrary bytes as the pixel data of an 8-bit RGB PNG image.

    The stored byte stream is a 5-byte header (big-endian payload length
    followed by a one-byte "compressed" flag), the payload itself, and
    ``0x80`` padding up to ``width * height * 3`` bytes.

    Args:
        compress: if true, zlib-compress the payload before storing it.
        width: image width in pixels; ``0`` selects one automatically.
        height: image height in pixels; ``0`` selects one automatically.
    """

    def __init__(self, compress=False, width=0, height=0):
        self.compress = compress
        self.width = width
        self.height = height
        self.depth = 8  # bit depth written to IHDR
        self.color = 2  # colour type written to IHDR (truecolour, 3 samples/pixel)
        self.data = []

    def pixels(self, data):
        """Queue another byte string to be stored by :meth:`save`."""
        self.data.append(data)

    def __chunk(self, name, data):
        """Return one serialized chunk: length, 4-byte type, data, CRC-32.

        The CRC covers the chunk type and the data, not the length field.
        """
        size = struct.pack('>I', len(data))
        crc = struct.pack('>I', binascii.crc32(name + data) & 0xffffffff)
        return size + name + data + crc

    def save(self):
        """Return the complete PNG file as a byte string.

        Dimensions left at ``0`` are chosen here from the amount of data
        queued so far.  They are kept in local variables rather than written
        back to ``self.width`` / ``self.height``, so ``save()`` can be called
        again after further ``pixels()`` calls and will size the image afresh.

        Raises:
            ValueError: if ``width`` or ``height`` is negative, or if the
                requested dimensions are too small to hold the data.
        """
        payload = b''.join(self.data)
        if self.compress:
            payload = zlib.compress(payload)

        # Header lets the reader strip the padding and know whether to inflate.
        payload = struct.pack('>I?', len(payload), self.compress) + payload

        width, height = self.width, self.height
        if width < 0 or height < 0:
            raise ValueError('width and height must not be negative')

        if width == 0:
            # Aim for a roughly 1.6:1 image, then widen to a multiple of 160.
            width = int(math.ceil(((len(payload) / 3.0) * 1.6) ** 0.5))
            remainder = width % 160
            if remainder:
                width += 160 - remainder

        row_bytes = width * 3
        if height == 0:
            # Ceiling division: enough rows to hold every payload byte.
            height = -(-len(payload) // row_bytes)

        capacity = row_bytes * height
        if len(payload) > capacity:
            # Writing anyway would emit more scanlines than IHDR declares.
            raise ValueError(
                'data does not fit in a %dx%d image' % (width, height))
        payload += b'\x80' * (capacity - len(payload))

        ihdr = struct.pack('>IIBBBBB', width, height,
                           self.depth, self.color, 0, 0, 0)

        # Every scanline is prefixed with filter type 0 (no filtering).
        scanlines = [
            b'\x00' + payload[start:start + row_bytes]
            for start in range(0, len(payload), row_bytes)
        ]

        return b''.join([
            MAGIC,
            self.__chunk(b'IHDR', ihdr),
            self.__chunk(b'IDAT', zlib.compress(b''.join(scanlines))),
            self.__chunk(b'IEND', b''),
        ])


class PNGReader(object):
    """Recover the bytes stored in the pixel data of a PNG image.

    Only a narrow subset of PNG is handled: scanlines of 3 bytes per pixel
    that all use filter type 0, with each IDAT chunk holding a complete zlib
    stream (every IDAT chunk is decompressed on its own, so image data split
    across several IDAT chunks is not supported).

    Malformed input is reported as ``ValueError``.
    """

    def __init__(self):
        self.data = []  # decoded scanlines (filter byte removed), in order

    @classmethod
    def read(cls, png):
        """Return the payload stored in the PNG byte string *png*.

        The pixel bytes start with a 5-byte header: big-endian payload length
        followed by a one-byte flag telling whether the payload is
        zlib-compressed.

        Raises:
            ValueError: if *png* is not a well-formed PNG of the supported
                kind, or the stored header is missing or inconsistent.
        """
        reader = cls()
        reader.parse(png)
        stored = b''.join(reader.data)

        if len(stored) < 5:
            raise ValueError('Bad data')
        size, compress = struct.unpack('>I?', stored[:5])
        if size > len(stored) - 5:
            # Header promises more bytes than the image actually holds.
            raise ValueError('Bad data')

        payload = stored[5:5 + size]
        if compress:
            payload = zlib.decompress(payload)
        return payload

    def parse(self, png):
        """Walk the chunks of *png*, dispatching each to ``on_<TYPE>``.

        Chunks without a matching handler method are skipped.  Parsing stops
        when a handler returns a true value or the input is exhausted.

        Raises:
            ValueError: on a missing signature, a truncated chunk, or a
                CRC mismatch.
        """
        if not png.startswith(MAGIC):
            raise ValueError('Invalid PNG')

        total = len(png)
        pos = len(MAGIC)

        # Advance an offset instead of re-slicing the remaining input, which
        # would copy the tail of the file once per chunk.
        while pos < total:
            if total - pos < 8:
                raise ValueError('Truncated PNG')
            size, name = struct.unpack_from('>I4s', png, pos)
            name = name.decode('ascii')

            end = pos + 8 + size
            if total - end < 4:
                raise ValueError('Truncated PNG')

            # The CRC covers the chunk type and data, not the length field.
            crc = binascii.crc32(png[pos + 4:end]) & 0xffffffff
            check, = struct.unpack_from('>I', png, end)
            if crc != check:
                raise ValueError('CRC Mismatch')

            handler = getattr(self, 'on_' + name, None)
            if handler and handler(png[pos + 8:end]):
                break

            pos = end + 4

    def on_IHDR(self, data):
        """Record the image header fields as attributes."""
        if len(data) != 13:
            raise ValueError('Bad IHDR')
        (
            self.width,
            self.height,
            self.depth,
            self.color,
            self.compression,
            self.filter,
            self.interlace,
        ) = struct.unpack('>IIBBBBB', data)

    def on_IDAT(self, data):
        """Inflate the image data and collect each scanline's pixel bytes.

        Raises:
            ValueError: if the inflated size is not ``height`` scanlines of
                one filter byte plus ``width * 3`` pixel bytes, or a scanline
                uses a filter type other than 0.
        """
        # bytearray so indexing yields an int on both Python 2 and 3.
        raw = bytearray(zlib.decompress(data))
        stride = self.width * 3 + 1
        if len(raw) != stride * self.height:
            raise ValueError('Bad data')

        for start in range(0, len(raw), stride):
            if raw[start] != 0:
                raise ValueError('Unsupported filter')
            self.data.append(bytes(raw[start + 1:start + stride]))

    def on_IEND(self, data):
        """Signal the parser to stop at the end-of-image chunk."""
        return True


def _run_cli():
    """Command-line entry point: convert between raw data and PNG.

    Input that starts with the PNG signature is decoded back to the bytes
    stored in it; any other input is wrapped into a PNG image.
    """
    import argparse

    # python 2/3 support: prefer the stream's ``buffer`` attribute when it
    # has one, otherwise use the stream object itself.
    raw_in = getattr(sys.stdin, 'buffer', sys.stdin)
    raw_out = getattr(sys.stdout, 'buffer', sys.stdout)

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-c', '--compress',
        action='store_true',
        help='compress data to be stored',
    )
    cli.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('rb'),
        default=raw_in,
        help='file to read data from (default stdin)',
    )
    cli.add_argument(
        'output',
        nargs='?',
        type=argparse.FileType('wb'),
        default=raw_out,
        help='file to write data to (default stdout)',
    )
    options = cli.parse_args()

    payload = options.input.read()

    if payload.startswith(MAGIC):
        # Already a PNG: extract what was stored inside it.
        result = PNGReader.read(payload)
    else:
        writer = PNGWriter(options.compress)
        writer.pixels(payload)
        result = writer.save()

    options.output.write(result)


if __name__ == '__main__':
    _run_cli()
