#!/usr/bin/env python3

"""Apply black to code cells in jupyter notebooks."""

import sys
import os

import click
import nbformat
import black

from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Pattern,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
)


@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
        allow_dash=True,
    ),
    is_eager=True,
)
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=79,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-x",
    "--exclude",
    type=str,
    default=[],
    help="Path patterns to exclude alongside '.ipynb_checkpoints'",
    show_default=True,
    multiple=True,
)
def main(
    src: Tuple[str, ...], line_length: int, exclude: Iterable[str]
) -> None:
    """Apply black to code cells in jupyter notebooks
    underneath the src path."""
    # NOTE: click renders the docstring above as the command's --help text,
    # so implementation notes are kept in comments rather than added to it.
    #
    # src:         paths (files or directories) given on the command line.
    # line_length: maximum line length handed to black for every code cell.
    # exclude:     extra substrings; notebooks whose path contains one of
    #              them are skipped.

    if not src:
        click.secho("No path given. Nothing to do 😴")
        # sys.exit() instead of the exit() builtin: the latter is injected
        # by the site module for interactive use and is not always defined.
        sys.exit(0)

    # Checkpoint copies are always skipped, on top of any user patterns.
    patterns = list(exclude)
    patterns.append(".ipynb_checkpoints")

    notebook_paths = find_notebooks(src, patterns)

    if not notebook_paths:
        click.secho("Can't find any notebooks. Nothing to do 😴")
        sys.exit(0)

    for notebook_path in notebook_paths:
        click.secho(notebook_path)
        notebook = nbformat.read(notebook_path, as_version=nbformat.NO_CONVERT)

        try:
            # Cells are updated in place; nothing is written until every
            # code cell of the notebook has been formatted successfully.
            for cell in notebook["cells"]:
                if cell["cell_type"] == "code":
                    cell["source"] = format_cell(cell["source"], line_length)
        except black.InvalidInput:
            # A single unparsable cell leaves the file on disk untouched.
            click.secho("💥 💔 💥")
            continue

        nbformat.write(notebook, notebook_path)
        click.secho("✨ 🍰 ✨")


def format_cell(source: str, line_length: int) -> str:
    """
    Format the source of one code cell with black.

    A cell for which cell_magic_in() is true is returned unchanged.
    Otherwise each line is passed through hide_line_magic(), the result
    is formatted by black with the given line length, and the markers are
    removed again by reveal_line_magic(). If the (disguised) source ended
    with ";" ignoring trailing whitespace, the formatted text is
    right-stripped and a ";" is appended to it.

    Any exception raised by black.format_str() propagates to the caller.
    """
    if cell_magic_in(source):
        return source

    disguised = "\n".join(
        hide_line_magic(line) for line in source.splitlines()
    )

    # endswith() is False for an empty string, so blank cells need no
    # special handling here.
    had_semicolon = disguised.rstrip().endswith(";")

    result = black.format_str(disguised, line_length)
    if had_semicolon:
        result = result.rstrip() + ";"

    return reveal_line_magic(result)


def cell_magic_in(source: str) -> bool:
    """
    Tell whether any line of the cell starts with "%%".

    Leading whitespace on a line is ignored for this check.
    """
    for line in source.splitlines():
        if line.lstrip().startswith("%%"):
            return True
    return False


def hide_line_magic(line: str) -> str:
    """
    Disguise a magic line as a comment so black leaves it in place.

    A line whose very first character is "%" gets the "###MAGIC###"
    marker prepended; every other line (including an empty one or one
    where the "%" is preceded by whitespace) is returned unchanged.
    """
    if line.startswith("%"):
        return "###MAGIC###" + line
    return line


def reveal_line_magic(source: str) -> str:
    """
    Undo hide_line_magic() by stripping every "###MAGIC###" marker.

    All occurrences of the marker in the text are removed, wherever they
    appear.
    """
    marker = "###MAGIC###"
    return source.replace(marker, "")


def find_notebooks(
    srcs: Iterable[str], exclude: Union[str, Iterable[str]]
) -> List[str]:
    """
    Find all notebooks below each src path.

    Every directory in ``srcs`` is walked recursively and each file whose
    name ends in ".ipynb" (compared case-insensitively) is collected,
    unless its joined path contains one of the ``exclude`` substrings.
    A src that is not a directory is returned as given when it ends in
    ".ipynb"; the exclusion patterns are not applied to it. Any other
    src is ignored.

    ``exclude`` may be a single pattern string or any iterable of pattern
    strings (including a one-shot iterator). An empty string passed on
    its own means "no exclusions".

    Returns the list of matching paths in the order they were found.
    """
    if isinstance(exclude, str):
        # A bare string is one pattern, not a sequence of characters.
        patterns = (exclude,) if exclude else ()
    else:
        # Materialise once so an iterator is not exhausted by the first
        # file that gets checked.
        patterns = tuple(exclude)

    notebook_paths = []

    for src in srcs:
        if os.path.isdir(src):
            # Bottom-up traversal is kept so the result order is unchanged.
            for root, _dirs, files in os.walk(src, topdown=False):
                for name in files:
                    if not name.lower().endswith(".ipynb"):
                        continue
                    filepath = os.path.join(root, name)
                    if not any(pattern in filepath for pattern in patterns):
                        notebook_paths.append(filepath)
        elif src.lower().endswith(".ipynb"):
            notebook_paths.append(src)

    return notebook_paths


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
