#!/usr/bin/env python

"""Listen to a pulseaudio mic and plot a live spectrum in the terminal."""

import argparse
import math
import shutil
import struct
import sys
import typing

import pasimple


# Sample encoding requested from PulseAudio: 32-bit little-endian floats.
# main() decodes the stream with struct format '<f', which must stay in sync.
SAMPLE_FORMAT = pasimple.PA_SAMPLE_FLOAT32LE
# NOTE(review): not referenced anywhere in the visible code; the --window
# option independently defaults to the same value (512).
CHUNK_SIZE = 512
# Small positive power: added to the plot normalization and passed as the
# epsilon of log_01() when coloring the level dot.
EPSILON_POWER = 1e-6


def cis(t: float) -> complex:
    """Return the unit complex number cos(t) + i*sin(t).

    Same value as cmath.exp(1j * t), built from the real trig functions.
    """
    real_part = math.cos(t)
    imag_part = math.sin(t)
    return complex(real_part, imag_part)


def fft2_recursive(l: typing.Union[list[float], list[complex]]) -> list[complex]:
    r"""Normalized DFT using the Cooley-Tukey algorithm in pure python.

    Every recursion level halves its result, so the output is the DFT scaled
    by 1/N: ``out[k] = (1/N) * sum(l[j] * exp(-2j*pi*j*k/N))``.

    numpy would be faster but like 70 megs on disk. ¯\_(ツ)_/¯

    Args:
        l: samples to transform; the length must be a power of 2 (1, 2, 4, ...).

    Returns:
        A list of ``len(l)`` complex coefficients.

    Raises:
        ValueError: if ``l`` is empty or its length is not a power of 2.
    """
    n = len(l)
    # An empty list would otherwise split into two empty halves forever.
    if n == 0:
        raise ValueError('only power-of-2 lengths supported')
    # base case: a single sample is its own transform
    if n == 1:
        return [complex(l[0])]
    # inductive requirement (an odd factor surfaces at some recursion depth)
    if n % 2 != 0:
        raise ValueError('only power-of-2 lengths supported')
    # inductive case: transform the even- and odd-indexed samples separately
    even = fft2_recursive(l[0::2])
    odd = fft2_recursive(l[1::2])
    # twiddle factors rotate the odd half into place
    shifted_odd = [cis(-2 * math.pi * k / n) * o for k, o in enumerate(odd)]
    # butterfly: first half is even+odd, second half is even-odd; the 0.5
    # factors apply this level's share of the 1/N normalization
    return (
        [0.5 * e + 0.5 * o for e, o in zip(even, shifted_odd)]
        + [0.5 * e - 0.5 * o for e, o in zip(even, shifted_odd)]
    )


# Bar glyphs indexed by height in eighths: a space for "empty", then the
# code points U+2581 through U+2588 in order.
BLOCKS = ' ' + ''.join(chr(0x2581 + eighth) for eighth in range(8))


def mean(l: list[float]) -> float:
    """Return the arithmetic mean of ``l``, or NaN when ``l`` is empty."""
    count = len(l)
    if count == 0:
        # callers detect "no data" with math.isfinite() instead of an exception
        return math.nan
    return sum(l) / count


class SpectrumTracker:
    """Holds some state to draw a spectrum histogram smoothed over time."""

    def __init__(self, spectrum_size: int) -> None:
        """Start from an all-zero spectrum with ``spectrum_size`` bins."""
        # exponentially smoothed power per frequency bin
        self.spectrum = [0. for _ in range(spectrum_size)]
        # running scale the bars are measured against; never below the
        # current tallest bucket
        self.norm = 1.

    def plot(self, powers: list[float], width: int) -> str:
        """Fold one frame of powers into the state and render it as one line.

        Args:
            powers: power per frequency bin; must have exactly as many entries
                as the tracker was created with.
            width: number of characters (buckets) to draw.

        Returns:
            A string of ``width`` block characters, or ``''`` when ``width``
            is zero or negative.

        Raises:
            ValueError: if ``len(powers)`` differs from the tracked spectrum size.
        """
        # check alignment
        n = len(powers)
        if n != len(self.spectrum):
            raise ValueError(n, len(self.spectrum))

        # smooth it out over time
        for i, x in enumerate(powers):
            self.spectrum[i] = self.spectrum[i] * 0.8 + x * 0.2

        # No room to draw (e.g. a very narrow terminal): keep the smoothing
        # up to date but skip rendering, which would otherwise fail on
        # max() of an empty bucket list.
        if width <= 0:
            return ''

        # distribute buckets on a logarithmic (musical pitch) scale
        bucket_edges = [0] + [round(n**((i + 1) / width)) for i in range(width)]

        # compute mean power in each bucket (NaN where a bucket has no bins)
        buckets = [
            mean(self.spectrum[lo:hi])
            for lo, hi in zip(bucket_edges, bucket_edges[1:])
        ]

        # infill empty buckets with the nearest filled bucket to their left
        prev = 0
        for i, x in enumerate(buckets):
            if math.isfinite(x):
                prev = x
            else:
                buckets[i] = prev

        # normalize with moving average of max; the outer max() keeps the
        # tallest bucket from exceeding the scale
        peak = max(buckets)
        self.norm = max(self.norm * 0.8 + peak * 0.2 + EPSILON_POWER, peak)

        # convert to ⅛ box drawing chars on a sqrt (approx perceptual loudness) scale
        eighths = [min(max(0, math.floor(math.sqrt(x / self.norm) * 9)), 8) for x in buckets]
        return ''.join(BLOCKS[x] for x in eighths)


def log_01(x: float, epsilon: float) -> float:
    """Map [0, 1] onto [0, 1], roughly linear below epsilon and log above.

    The value log(x + epsilon) is rescaled so that x = 0 gives 0 and x = 1
    gives 1.
    """
    low = math.log(epsilon)
    high = math.log(1 + epsilon)
    shifted = math.log(x + epsilon)
    return (shifted - low) / (high - low)


def unit_to_256(x: float) -> int:
    """Scale a float in [0, 1] to an int in [0, 255], clamping out-of-range input."""
    level = math.floor(x * 256)
    if level < 0:
        return 0
    if level > 255:
        return 255
    return level


def main():
    """Script entry point.

    Parse the command line, open a one-channel PulseAudio record stream and
    loop until interrupted: read one window of samples, turn it into a power
    spectrum, then redraw one terminal line holding a colored level dot
    followed by the spectrum histogram.

    Raises:
        ValueError: if ``--window`` is not a positive power of 2.
        SystemExit: with status 130 on Ctrl+C.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device-name',
                        help="pulseaudio device name from `pactl list sources | grep Name:`"
                             " (default: use fallback)")
    parser.add_argument('--sample-rate', type=int, default=44100,
                        help="audio samples per second (default: 44100)")
    parser.add_argument('--window', type=int, default=512,
                        help="audio samples per drawn frame; powers of 2 only (default: 512)")
    args = parser.parse_args()

    # Exact integer test (a power of 2 has a single bit set). Unlike a
    # math.log2 round-trip it also gives zero and negative windows this same
    # readable message instead of "math domain error".
    if args.window < 1 or (args.window & (args.window - 1)) != 0:
        raise ValueError(f'{args.window} is not a power of 2')

    # bytes in one window of samples; used for the fragment size and each read
    window_bytes = args.window * pasimple.format2width(SAMPLE_FORMAT)

    with pasimple.PaSimple(
        pasimple.PA_STREAM_RECORD,
        SAMPLE_FORMAT,
        1, args.sample_rate, 'micpeek',
        fragsize=window_bytes,
        device_name=args.device_name,
    ) as audio_input:
        device_name = args.device_name if args.device_name is not None else "default device"
        # the trailing blank line is the one the loop below redraws in place
        print(f'Listening to {device_name}', end='\n\n')
        try:
            spectrum_tracker = SpectrumTracker(args.window // 2 + 1)
            slow_power = 0.
            fast_power = 0.
            while True:
                # grab audio samples and convert to power spectrum
                buffer = audio_input.read(window_bytes)
                samples = [x[0] for x in struct.iter_unpack('<f', buffer)]
                amplitudes = fft2_recursive(samples)
                # keep bins 0..N/2 only
                top_frequency = len(amplitudes) // 2
                powers = [abs(z * z.conjugate()) for z in amplitudes[:top_frequency + 1]]

                # plot a dot
                width = shutil.get_terminal_size().columns
                total_power = sum(powers)
                # fast: jumps straight up to a new peak, then decays by half per frame
                fast_power = max(total_power, 0.5 * fast_power + 0.5 * total_power)
                # slow: moving average of the log-scaled level
                slow_power = 0.9 * slow_power + 0.1 * log_01(total_power, EPSILON_POWER)
                red = green = unit_to_256(log_01(fast_power, EPSILON_POWER))
                blue = unit_to_256(slow_power)
                print('\x1b[1A', end='')  # reset cursor position
                print(f'\x1b[38;2;{red};{green};{blue}m\u2b24\x1b[m ', end='')

                # plot the spectrum in the space that remains; always ask for
                # at least one column so a tiny terminal cannot crash plot()
                print(spectrum_tracker.plot(powers, max(1, width - 3)))
        except KeyboardInterrupt:
            sys.exit(130)  # exit gracefully on Ctrl+C


# Start the monitor only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()
