#!/usr/bin/env python3

from argparse import ArgumentParser, Namespace
from codecs import iterdecode
from contextlib import contextmanager, nullcontext
from enum import Enum, auto
from io import DEFAULT_BUFFER_SIZE, BufferedIOBase
from itertools import chain, count, cycle, islice, repeat
from locale import getpreferredencoding
from math import pi, sin
from os import environ, linesep
from os import name as os_name
from random import choice, randint
from shutil import get_terminal_size
from sys import stderr, stdin, stdout
from typing import Callable, Iterator, Mapping, Optional, Sequence, Tuple, cast
from unicodedata import east_asian_width


class _ColourSpace(Enum):
    """Colour depth used when emitting escape sequences.

    Each member's value is the string accepted by the ``-c/--colour`` option.
    """

    EIGHT = "8"  # indexed colours, emitted as "38;5" / "48;5" sequences
    TRUE = "24"  # direct RGB colours, emitted as "38;2" / "48;2" sequences

    def __str__(self) -> str:
        # Render as the bare value ("8") rather than "_ColourSpace.EIGHT".
        return str(self.value)


class _Gradient(Enum):
    """Interpolation mode for the colour gradient.

    Each member's value is the string accepted by ``-i/--interpolation``.
    """

    D1 = "1d"  # one gradient that keeps running across line breaks
    D2 = "2d"  # gradient that depends on both column and row

    def __str__(self) -> str:
        # Render as the bare value ("1d") rather than "_Gradient.D1".
        return str(self.value)


class _Flags(Enum):
    """Flags that can be selected; each member is a key of ``_FLAG_SPECS``."""

    LES = auto()  # --les / --lesbian / --wlw
    GAY = auto()  # --gay
    BI = auto()  # --bi / --bisexual
    TRANS = auto()  # --trans / --transgender
    ACE = auto()  # --ace / --asexual
    PAN = auto()  # --pan / --pansexual
    NB = auto()  # --nb / --non-binary
    GQ = auto()  # --gq / --gender-queer
    MLM = auto()  # --mlm
    ARO = auto()  # --aro / --aromantic
    POLY = auto()  # --poly / --polysexual
    DEMIBOY = auto()  # --db / --demiboy
    DEMIGIRL = auto()  # --dg / --demigirl
    AGENDER = auto()  # --ag / --agender
    BIGENDER = auto()  # --bg / --bigender
    GENDERFLUID = auto()  # --gf / --genderfluid
    ABRO = auto()  # --abro / --abrosexual
    NEUTROIS = auto()  # --nt / --neut / --neutrois
    TRIGENDER = auto()  # --tri / --trigender


# Colour written as "#" followed by hex digit pairs, e.g. "#FF0018".
_HexColour = str
# One colour as three integer channels (red, green, blue).
_RGB = Tuple[int, int, int]
# A flag's stripes as hex strings / as parsed channel tuples.
_RawPalette = Sequence[_HexColour]
_Palette = Sequence[_RGB]

# True when running on Windows ("nt"); used to skip the SIGPIPE setup.
_WINDOWS = os_name == "nt"
# Terminal size in columns and lines, sampled once at import time; the
# (80, 80) tuple is the fallback handed to get_terminal_size.
_COLS, _ROWS = get_terminal_size((80, 80))

# Terminal column width keyed by the East Asian Width class returned by
# unicodedata.east_asian_width. Classes that are not listed here fall back
# to a width of 1 where the table is consulted.
_UNICODE_WIDTH_LOOKUP = {
    "W": 2,  # Wide, e.g. CJK ideographs
    "F": 2,  # Fullwidth forms take two columns, just like Wide
    # Neutral: this is the class of control characters such as "\n", which
    # take no column. NOTE(review): Neutral also covers printable characters
    # from scripts outside East Asian typography, which are therefore
    # treated as zero-width too -- confirm whether that is intended.
    "N": 0,
}

# Flag proportions; _paint_flag unpacks this as (r, c) and uses r / c when
# choosing how many terminal rows each stripe gets.
_ASPECT_RATIO = (3, 5)
# Stripe colours for every selectable flag, in the order the stripes are
# painted (first entry is the top row). Hex digits may be upper or lower case.
_FLAG_SPECS: Mapping[_Flags, _RawPalette] = {
    _Flags.LES: ("#D62E02", "#FD9855", "#FFFFFF", "#D161A2", "#A20160"),
    _Flags.GAY: ("#FF0018", "#FFA52C", "#FFFF41", "#008018", "#0000F9", "#86007D"),
    _Flags.BI: ("#D60270", "#D60270", "#9B4F96", "#0038A8", "#0038A8"),
    _Flags.TRANS: ("#55CDFC", "#F7A8B8", "#FFFFFF", "#F7A8B8", "#55CDFC"),
    _Flags.ACE: ("#000000", "#A4A4A4", "#FFFFFF", "#810081"),
    _Flags.PAN: ("#FF1B8D", "#FFDA00", "#1BB3FF"),
    _Flags.NB: ("#FFF430", "#FFFFFF", "#9C59D1", "#000000"),
    _Flags.GQ: ("#B77FDD", "#FFFFFF", "#48821E"),
    _Flags.MLM: (
        "#078D70",
        "#27CEAA",
        "#98E8C1",
        "#ffffff",
        "#7bade2",
        "#5049cc",
        "#3d1a78",
    ),
    _Flags.ARO: ("#3AA63F", "#A8D47A", "#FFFFFF", "#AAAAAA", "#000000"),
    _Flags.POLY: ("#F61BB9", "#07D669", "#1C92F5"),
    _Flags.DEMIBOY: (
        "#7F7F7F",
        "#C3C3C3",
        "#9AD9EA",
        "#FFFFFF",
        "#9AD9EA",
        "#C3C3C3",
        "#7F7F7F",
    ),
    _Flags.DEMIGIRL: (
        "#7F7F7F",
        "#C3C3C3",
        "#FEAEC9",
        "#FFFFFF",
        "#FEAEC9",
        "#C3C3C3",
        "#7F7F7F",
    ),
    _Flags.AGENDER: (
        "#000000",
        "#BCC5C6",
        "#FFFFFF",
        "#B5F582",
        "#FFFFFF",
        "#BCC5C6",
        "#000000",
    ),
    _Flags.BIGENDER: (
        "#C479A0",
        "#ECA6CB",
        "#D5C7E8",
        "#FFFFFF",
        "#D5C7E8",
        "#9AC7E8",
        "#6C83CF",
    ),
    _Flags.GENDERFLUID: ("#FE75A1", "#FFFFFF", "#BE18D6", "#000000", "#333EBC"),
    _Flags.ABRO: ("#75CA92", "#B2E4C5", "#FFFFFF", "#E695B5", "#DA446C"),
    _Flags.NEUTROIS: ("#FFFFFF", "#2F9C1D", "#000000"),
    _Flags.TRIGENDER: ("#FF95C5", "#9580FF", "#67D967", "#9580FF", "#FF95C5"),
}


def _parse_args() -> Namespace:
    """Declare the command-line interface and parse the process arguments.

    Returns:
        A namespace with ``encoding``, ``flag_only``, ``flag``, ``colour``,
        ``interpolation``, ``period`` and ``tabs``.
    """
    truecolour = environ.get("COLORTERM") in {"truecolor", "24bit"}
    default_space = _ColourSpace.TRUE if truecolour else _ColourSpace.EIGHT
    # The namespace is pre-seeded with a random flag: argparse only fills in
    # defaults for attributes that are missing, so this value stays unless
    # one of the flag switches below overwrites it.
    preset = Namespace(flag=choice(tuple(_Flags)))

    parser = ArgumentParser()
    parser.add_argument("--encoding", default=getpreferredencoding())

    mode_group = parser.add_argument_group()
    mode_group.add_argument("-f", "--flag", dest="flag_only", action="store_true")

    # Option strings for each flag, in the order they are registered.
    flag_switches = (
        (("-l", "--les", "--lesbian", "--wlw"), _Flags.LES),
        (("-g", "--gay"), _Flags.GAY),
        (("-b", "--bi", "--bisexual"), _Flags.BI),
        (("-t", "--trans", "--transgender"), _Flags.TRANS),
        (("-a", "--ace", "--asexual"), _Flags.ACE),
        (("-p", "--pan", "--pansexual"), _Flags.PAN),
        (("-n", "--nb", "--non-binary"), _Flags.NB),
        (("--gq", "--gender-queer"), _Flags.GQ),
        (("--mlm",), _Flags.MLM),
        (("--aro", "--aromantic"), _Flags.ARO),
        (("--poly", "--polysexual"), _Flags.POLY),
        (("--db", "--demiboy"), _Flags.DEMIBOY),
        (("--dg", "--demigirl"), _Flags.DEMIGIRL),
        (("--ag", "--agender"), _Flags.AGENDER),
        (("--bg", "--bigender"), _Flags.BIGENDER),
        (("--gf", "--genderfluid"), _Flags.GENDERFLUID),
        (("--abro", "--abrosexual"), _Flags.ABRO),
        (("--nt", "--neut", "--neutrois"), _Flags.NEUTROIS),
        (("--tri", "--trigender"), _Flags.TRIGENDER),
    )
    flag_group = parser.add_argument_group()
    for option_strings, member in flag_switches:
        # Every switch writes to the same destination, so the last one given
        # on the command line wins.
        flag_group.add_argument(
            *option_strings,
            action="store_const",
            dest="flag",
            const=member,
        )

    opt_group = parser.add_argument_group()
    opt_group.add_argument(
        "-c",
        "--colour",
        type=_ColourSpace,
        choices=tuple(_ColourSpace),
        default=default_space,
    )
    opt_group.add_argument(
        "-i",
        "--interpolation",
        type=_Gradient,
        choices=tuple(_Gradient),
        default=choice(tuple(_Gradient)),
    )
    # Both numeric options are coerced to a positive integer (at least 1).
    opt_group.add_argument("--period", type=lambda i: max(abs(int(i)), 1))
    opt_group.add_argument(
        "--tabs", "--tab-width", type=lambda i: max(abs(int(i)), 1), default=4
    )

    return parser.parse_args(namespace=preset)


def _trap_sig() -> None:
    """Set SIGPIPE back to its default disposition (skipped on Windows)."""
    if _WINDOWS:
        return

    # Imported here because SIGPIPE is only looked up on non-Windows systems.
    from signal import SIG_DFL, SIGPIPE, signal

    signal(SIGPIPE, SIG_DFL)


def _stdin_stream() -> Iterator[bytes]:
    """Yield raw byte chunks from standard input until an empty read."""
    read_chunk = cast(BufferedIOBase, stdin.buffer).read1
    chunk = read_chunk(DEFAULT_BUFFER_SIZE)
    while chunk:
        yield chunk
        chunk = read_chunk(DEFAULT_BUFFER_SIZE)


def _stream(encoding: str) -> Iterator[str]:
    """Decode standard input with *encoding* and iterate it one character at a time.

    Undecodable bytes are substituted (``errors="replace"``) rather than raised.
    """
    decoded_chunks = iterdecode(_stdin_stream(), encoding=encoding, errors="replace")
    # Each decoded chunk is a str; flatten the chunks into single characters.
    return chain.from_iterable(decoded_chunks)


def _normalize_width(tab_width: int, stream: Iterator[str]) -> Iterator[str]:
    """Replace every tab in *stream* with *tab_width* single-space characters.

    All other characters pass through unchanged, one at a time.
    """
    for char in stream:
        if char != "\t":
            yield char
            continue
        # Iterating the padding string emits the spaces one by one.
        yield from " " * tab_width


def _unicode_width(stream: Iterator[str]) -> Iterator[Tuple[int, str]]:
    """Pair every character of *stream* with its width in terminal columns.

    Yields:
        ``(width, char)`` tuples in input order.
    """

    def columns(char: str) -> int:
        # Best effort: anything that cannot be classified counts as one column.
        try:
            return _UNICODE_WIDTH_LOOKUP.get(east_asian_width(char), 1)
        except Exception:
            return 1

    yield from ((columns(char), char) for char in stream)


def _parse_raw_palette(raw: _RawPalette) -> _Palette:
    """Convert hex colour strings into tuples of integer channels.

    The first character of each colour (the "#") is dropped and the rest is
    read two hex digits at a time.
    """

    def parse_colour(colour: _HexColour) -> _RGB:
        digits = colour[1:]
        channels = []
        for start in range(0, len(digits), 2):
            # Unpacking requires a full pair, so a dangling digit is an error.
            high, low = digits[start : start + 2]
            channels.append(int(f"{high}{low}", 16))
        return cast(_RGB, tuple(channels))

    return tuple(map(parse_colour, raw))


# Takes a colour and produces the pieces of the escape sequence selecting it.
_DecorateChar = Callable[[_RGB], Iterator[str]]
# Takes nothing and produces the pieces of the attribute-reset sequence.
_DecorateReset = Callable[[], Iterator[str]]


def _decor_for(
    space: _ColourSpace,
) -> Tuple[_DecorateChar, _DecorateChar, _DecorateReset]:
    """Build escape-sequence emitters for the given colour space.

    Returns:
        ``(fg, bg, reset)``: ``fg`` and ``bg`` take a colour and yield the
        fragments of the sequence that selects it as foreground/background;
        ``reset`` yields the sequence that clears all attributes.

    Raises:
        ValueError: if *space* is not a known colour space.
    """

    def as_index(rgb: _RGB) -> Sequence[str]:
        # Quantise each channel to 0..5 and address the 6x6x6 block of
        # indexed colours that begins at index 16.
        r, g, b = (int(round(channel / 255 * 5)) for channel in rgb)
        return ";", str(16 + 36 * r + 6 * g + b)

    def as_channels(rgb: _RGB) -> Sequence[str]:
        # ";<channel>" for every channel, in order.
        params = []
        for channel in rgb:
            params.extend((";", str(channel)))
        return params

    if space == _ColourSpace.EIGHT:
        fg_intro, bg_intro, encode = "\x1B[38;5", "\x1B[48;5", as_index
    elif space == _ColourSpace.TRUE:
        fg_intro, bg_intro, encode = "\x1B[38;2", "\x1B[48;2", as_channels
    else:
        raise ValueError()

    def painter(intro: str) -> _DecorateChar:
        def paint(colour: _RGB) -> Iterator[str]:
            yield intro
            yield from encode(colour)
            yield "m"

        return paint

    def reset() -> Iterator[str]:
        yield "\x1B[0m"

    return painter(fg_intro), painter(bg_intro), reset


def _paint_flag(colour_space: _ColourSpace, palette: _Palette) -> Iterator[str]:
    """Yield the output fragments that draw *palette* as horizontal stripes.

    Every stripe spans the full terminal width; all stripes share the same
    thickness (at least one row), chosen from the terminal size and
    ``_ASPECT_RATIO``.
    """
    r, c = _ASPECT_RATIO
    _, paint_bg, clear = _decor_for(colour_space)
    stripes = len(palette)
    ratio = r / c * 0.5
    # Limited by the rows available (minus 4) and by the width scaled with
    # the aspect ratio.
    by_rows = (_ROWS - 4) / stripes
    by_cols = _COLS / stripes * ratio
    thickness = max(int(min(by_rows, by_cols)), 1)
    blank_row = " " * _COLS

    for stripe in palette:
        for _ in repeat(None, thickness):
            yield from chain(paint_bg(stripe), (blank_row,), clear(), (linesep,))


def _enumerate_lines(stream: Iterator[str]) -> Iterator[Tuple[bool, int, str]]:
    """Annotate each character with whether it ends a terminal row.

    Args:
        stream: single characters, in output order.

    Yields:
        ``(new_line, width, char)`` for every input character. ``new_line``
        is True when the row ends after this character: it is a newline, it
        exactly fills the row, or the following character would not fit.
    """
    # One character is held back so it can still be flagged as the end of
    # the row if its successor turns out to overflow.
    pending: Optional[Tuple[int, str]] = None
    # Columns used on the current row, including the held-back character.
    # Starting from 0 with nothing pending lets the first character go
    # through the same rules as every other one (a leading newline or a
    # wide first character is accounted for correctly).
    x = 0

    def drain(ret: bool) -> Iterator[Tuple[bool, int, str]]:
        nonlocal pending
        if pending is not None:
            yield (ret, *pending)
            pending = None

    for width, char in _unicode_width(stream):
        new = x + width
        if new > _COLS:
            # Does not fit: the held-back character closes the row and this
            # one opens the next.
            yield from drain(True)
            pending = (width, char)
            x = width
        # Compare against "\n" rather than os.linesep: the stream is split
        # into single characters, which can never equal a multi-character
        # separator such as "\r\n".
        elif new == _COLS or char == "\n":
            yield from drain(False)
            yield True, width, char
            x = 0
        else:
            yield from drain(False)
            pending = (width, char)
            x = new

    yield from drain(False)


def _lerp(c1: _RGB, c2: _RGB, mix: float) -> _RGB:
    """Blend two colours channel by channel.

    *mix* is the weight given to *c1*; *c2* gets ``1 - mix``. Each blended
    channel is rounded to an integer.
    """
    blended = tuple(
        int(round(first * mix + second * (1 - mix))) for first, second in zip(c1, c2)
    )
    return cast(_RGB, blended)


def _sine_wave(t: float) -> float:
    """Map *t* onto a sine-shaped curve with values between 0 and 1.

    The result is 1 at ``t == 0`` and falls towards 0 as *t* approaches 1.
    """
    scaled = t * (pi * pi)
    # The pi / 2 phase shift makes the curve start at its peak.
    phase = scaled / pi + pi / 2
    return (sin(phase) + 1) / 2


def _interpolate_1d(palette: _Palette, rep: int) -> Iterator[Iterator[_RGB]]:
    """Produce a gradient that keeps running regardless of line breaks.

    The same colour iterator is handed out for every line, so each line
    picks up where the previous one stopped. Between two neighbouring
    palette colours ``rep + 1`` blended steps are produced.
    """
    wheel = cycle(palette)

    def blend() -> Iterator[_RGB]:
        start = next(wheel)
        for target in wheel:
            for step in range(rep + 1):
                yield _lerp(start, target, _sine_wave(step / rep))
            start = target

    shared = blend()
    while True:
        yield shared


# contributed by https://github.com/nshepperd
# https://github.com/ms-jpq/gay/issues/2
def _interpolate_2d(palette: _Palette, rep: int) -> Iterator[Iterator[_RGB]]:
    """Produce one endless colour iterator per row of a two-dimensional gradient.

    Args:
        palette: colours to cycle through.
        rep: number of columns a transition between two colours spans.

    Yields:
        For row ``y`` (0, 1, 2, ...), an endless iterator over the colours
        of columns 0, 1, 2, ... on that row. Every row is shifted by 1.5
        columns relative to the one before it.
    """
    num = len(palette)

    def line(y: int) -> Iterator[_RGB]:
        # The row index is a parameter, not a closed-over loop variable, so
        # an iterator keeps its own row even when it is consumed after the
        # outer generator has moved on.
        for x in count():
            p = (x + 1.5 * y) / rep
            i = int(p)
            prev = palette[i % num]
            curr = palette[(i + 1) % num]
            # The fractional part of p is the position inside the transition.
            mix = _sine_wave(p - i)
            yield _lerp(prev, curr, mix)

    for y in count():
        yield line(y)


def _interpolation_for(
    mode: _Gradient, palette: _Palette, rep: Optional[int]
) -> Iterator[Iterator[_RGB]]:
    """Create the per-line colour generator for the requested gradient mode.

    When *rep* is falsy a random period is drawn: 10-20 for the 1d mode and
    10-15 for the 2d mode.

    Raises:
        ValueError: if *mode* is not a known gradient.
    """
    if mode == _Gradient.D1:
        lowest, highest, build = 10, 20, _interpolate_1d
    elif mode == _Gradient.D2:
        lowest, highest, build = 10, 15, _interpolate_2d
    else:
        raise ValueError()

    period = rep if rep else randint(lowest, highest)
    return build(palette, period)


def _colourize(
    colour_space: _ColourSpace,
    rgb_gen: Iterator[Iterator[_RGB]],
    stream: Iterator[str],
) -> Iterator[str]:
    """Interleave the characters of *stream* with foreground colour sequences.

    Every character with a non-zero width is preceded by the next colour of
    the current line. At the end of a line the attributes are reset and the
    colour iterator for the following line is fetched from *rgb_gen*.
    """
    paint_fg, _, clear = _decor_for(colour_space)
    line_colours = next(rgb_gen)
    for ends_line, width, char in _enumerate_lines(stream):
        if width != 0:
            yield from paint_fg(next(line_colours))
        yield char
        if not ends_line:
            continue
        yield from clear()
        line_colours = next(rgb_gen)


@contextmanager
def _title() -> Iterator[None]:
    """Set the terminal title to "GAY" for the duration of the block.

    The title is cleared again on exit, including when the block raises.
    Nothing is written unless stderr is attached to a terminal.
    """
    atty = stderr.isatty()

    def cont(title: str) -> None:
        if not atty:
            return

        stderr.write("\x1B[0m")
        # Operating-system-command sequences have the form
        # "ESC ] <number> ; <text> ESC \": the ";" between the number and the
        # text is required in both variants.
        if "TMUX" in environ:
            stderr.write(f"\x1B]2;{title}\x1B\\")
        else:
            stderr.write(f"\x1B]0;{title}\x1B\\")

        stderr.flush()

    cont("GAY")
    try:
        yield None
    finally:
        cont("")


def _main() -> None:
    """Run the program: paint the chosen flag, or colourise standard input.

    The terminal title is set for the whole run.
    """
    with _title():
        _trap_sig()
        args = _parse_args()
        palette = _parse_raw_palette(_FLAG_SPECS[args.flag])

        if args.flag_only:
            # Flag mode: standard input is not read at all.
            stdout.writelines(_paint_flag(colour_space=args.colour, palette=palette))
            return

        text = _normalize_width(args.tabs, _stream(args.encoding))
        gradient = _interpolation_for(
            mode=args.interpolation, palette=palette, rep=args.period
        )
        coloured = _colourize(
            colour_space=args.colour,
            rgb_gen=gradient,
            stream=text,
        )
        stdout.writelines(coloured)


# Entry point. SystemExit is raised directly instead of calling the exit()
# helper: that helper is injected by the site module for interactive use and
# is not guaranteed to exist (for example under "python -S").
try:
    _main()
except KeyboardInterrupt:
    # 130 = 128 + 2, the conventional status for termination by SIGINT.
    raise SystemExit(130) from None
except BrokenPipeError:
    # NOTE(review): 13 presumably mirrors the SIGPIPE signal number -- confirm.
    raise SystemExit(13) from None
