#!/usr/bin/env python3
from __future__ import annotations

import collections
import os
import sys

import flake8_typing_imports
import mypy_extensions
import typing_extensions

if sys.version_info >= (3, 8):
    from importlib.metadata import version
else:
    from importlib_metadata import version


# --- typing_extensions notes ---
# https://github.com/python/typing_extensions#other-notes-and-limitations
# - Starting with Python 3.9, get_type_hints() has the include_extras parameter
# - get_origin and get_args lack support for Annotated in Python 3.8
#   and lack support for ParamSpecArgs and ParamSpecKwargs in 3.9.
# - Starting with 3.11, NamedTuple and TypedDict can inherit from Generic
# - @final was changed in Python 3.11 to set the .__final__ attribute
# - @overload was changed in Python 3.11 to make function overloads
#   introspectable at runtime.
# - Any was changed in Python 3.11 so it can be used as a base class
# - Considered for Python 3.12
#   - PEP 695 infer_variance parameter for TypeVar
#   - PEP 696 default parameter for TypeVar, TypeVarTuple, and ParamSpec
# Hand-maintained typing_extensions symbols, keyed by (major, minor) version.
# main() emits a `typing_extensions=typing:<name>` replacement for every name
# under its key, and removes these names from the set of typing_extensions
# symbols it derives automatically so they are only emitted at the version
# given here.
CUSTOM_TYPING_EXT_SYMBOLS = {
    (3, 9): {'get_type_hints'},
    (3, 10): {'get_origin', 'get_args'},
    (3, 11): {'Any', 'NamedTuple', 'TypedDict', 'final', 'overload'},
    (3, 12): {'ParamSpec', 'TypeVar', 'TypeVarTuple'},
}


def main() -> int:
    """Print the generated ``REPLACES[...]`` update statements to stdout.

    For each symbol in ``flake8_typing_imports.SYMBOLS`` the earliest version
    from which it has been continuously present is determined, rounded up to
    a ``(major, minor)`` pair, and turned into ``mypy_extensions=typing:<name>``
    / ``typing_extensions=typing:<name>`` entries for the names that those
    modules also expose.  The entries from ``CUSTOM_TYPING_EXT_SYMBOLS`` are
    then added under their own version keys.

    The output starts with a header naming this script and the installed
    versions of the three packages, and ends with ``# END GENERATED``.

    Returns:
        Always ``0``; the caller uses it as the process exit status.
    """
    flake8_typing_imports_version = version('flake8-typing-imports')
    mypy_extensions_version = version('mypy_extensions')
    typing_extensions_version = version('typing_extensions')

    # every module attribute is a candidate, except 'Any' which is
    # explicitly left out
    mypy_extensions_all = frozenset(
        a for a in dir(mypy_extensions) if a != 'Any'
    )
    # the hand-maintained symbols are emitted separately (see the end of
    # this function), so keep them out of the automatically derived set
    typing_extensions_all = frozenset(typing_extensions.__all__) - {
        sym for syms in CUSTOM_TYPING_EXT_SYMBOLS.values() for sym in syms
    }

    # some attrs are removed and then added back: forget a symbol as soon as
    # a version lacks it, so that only its latest unbroken run counts
    # NOTE(review): relies on SYMBOLS being ordered oldest-first -- confirm
    min_contiguous_versions: dict[str, flake8_typing_imports.Version] = {}
    for v, attrs in flake8_typing_imports.SYMBOLS:
        for removed in set(min_contiguous_versions) - attrs:
            del min_contiguous_versions[removed]

        for attr in attrs:
            min_contiguous_versions.setdefault(attr, v)

    # invert the mapping: version -> symbols first available in it
    symbols: dict[flake8_typing_imports.Version, set[str]]
    symbols = collections.defaultdict(set)
    for a, v in min_contiguous_versions.items():
        symbols[v].add(a)

    # --pyXX-plus assumes the min --pyXX so group symbols by their
    # rounded up major version
    symbols_rounded_up: dict[tuple[int, int], set[str]]
    symbols_rounded_up = collections.defaultdict(set)
    for v, attrs in sorted(symbols.items()):
        symbols_rounded_up[v.major, v.minor + int(v.patch != 0)] |= attrs

    # combine 3.5 and 3.6 because this lib is 3.7+
    # (tolerate a missing 3.5 bucket instead of raising KeyError; an empty
    # 3.6 bucket is skipped when printing)
    symbols_rounded_up[(3, 6)] |= symbols_rounded_up.pop((3, 5), set())

    # NOTE(review): each symbol lands in exactly one bucket above, so this
    # subtraction of the previous bucket looks like a no-op -- confirm
    # before simplifying
    deltas: dict[tuple[int, int], set[str]] = collections.defaultdict(set)
    prev: set[str] = set()
    for v, attrs in sorted(symbols_rounded_up.items()):
        deltas[v] = attrs - prev
        prev = attrs

    replaces: dict[tuple[int, int], set[str]] = collections.defaultdict(set)
    for v, attrs in deltas.items():
        replaces[v] |= {
            f'mypy_extensions=typing:{s}'
            for s in attrs & mypy_extensions_all
        } | {
            f'typing_extensions=typing:{s}'
            for s in attrs & typing_extensions_all
        }
    for v, attrs in CUSTOM_TYPING_EXT_SYMBOLS.items():
        replaces[v] |= {f'typing_extensions=typing:{s}' for s in attrs}

    print(f'# GENERATED VIA {os.path.basename(sys.argv[0])}')
    print('# Using:')
    print(f'#     flake8-typing-imports=={flake8_typing_imports_version}')
    print(f'#     mypy-extensions=={mypy_extensions_version}')
    print(f'#     typing-extensions=={typing_extensions_version}')

    for k, v in sorted(replaces.items()):
        # versions that ended up with nothing to replace produce no output
        if not v:
            continue
        print(f'REPLACES[{k}].update((')
        for replace in sorted(v):
            print(f'    {replace!r},')
        print('))')

    print('# END GENERATED')

    return 0


if __name__ == '__main__':
    # Run the generator and hand its return value to the interpreter as the
    # process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
