Coverage for src / remedapy / filter.py: 68%

25 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-02 10:52 +0100

1import inspect 

2from collections.abc import Callable, Iterable, Sequence 

3from typing import TypeVar, cast, overload 

4 

5T = TypeVar('T') 

6 

7 

8@overload 

9def filter(data: Iterable[T], predicate: Callable[[T], bool], /) -> Iterable[T]: ... 

10 

11 

12@overload 

13def filter(data: Iterable[T], predicate: Callable[[T, int], bool], /) -> Iterable[T]: ... 

14 

15 

16@overload 

17def filter(data: Sequence[T], predicate: Callable[[T, int, Sequence[T]], bool], /) -> Iterable[T]: ... 

18 

19 

20@overload 

21def filter(predicate: Callable[[T], bool], /) -> Callable[[Iterable[T]], Iterable[T]]: ... 

22 

23 

24@overload 

25def filter(predicate: Callable[[T, int], bool], /) -> Callable[[Iterable[T]], Iterable[T]]: ... 

26 

27 

28@overload 

29def filter(predicate: Callable[[T, int, Sequence[T]], bool], /) -> Callable[[Sequence[T]], Iterable[T]]: ... 

30 

31 

32def filter( 

33 data: Iterable[T] | Callable[[T], bool] | Callable[[T, int], bool] | Callable[[T, int, Sequence[T]], bool], 

34 callbackfn: Callable[[T], bool] | Callable[[T, int], bool] | Callable[[T, int, Sequence[T]], bool] | None = None, 

35 /, 

36) -> Iterable[T] | Callable[[Iterable[T]], Iterable[T]] | Callable[[Sequence[T]], Iterable[T]]: 

37 if callbackfn is None: 

38 callbackfn = cast(Callable[[T], bool], data) 

39 

40 def inner(data: Iterable[T], /) -> Iterable[T]: 

41 return filter(data, callbackfn) 

42 

43 return inner 

44 data = cast(list[T], data) 

45 sig = inspect.signature(callbackfn) 

46 num_params = len(sig.parameters) 

47 match num_params: 

48 case 1: 

49 callbackfn = cast(Callable[[T], bool], callbackfn) 

50 return (item for item in data if callbackfn(item)) 

51 case 2: 

52 callbackfn = cast(Callable[[T, int], bool], callbackfn) 

53 return (item for index, item in enumerate(data) if callbackfn(item, index)) 

54 case 3: 

55 callbackfn = cast(Callable[[T, int, Sequence[T]], bool], callbackfn) 

56 return (item for index, item in enumerate(data) if callbackfn(item, index, data)) 

57 case _: 

58 raise ValueError(f'Unsupported number of parameters: {num_params}')