Coverage for jinja2_async_environment/loaders/function.py: 80%
50 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 14:09 -0700
1"""Async function template loader implementation."""
3import typing as t
5from anyio import Path as AsyncPath
6from jinja2.utils import internalcode
8from .base import AsyncBaseLoader, SourceType
10if t.TYPE_CHECKING:
11 from ..environment import AsyncEnvironment
13# Type alias for loader functions
14LoaderFunction = t.Callable[[str], str | None]
15AsyncLoaderFunction = t.Callable[[str], t.Awaitable[str | None]]
class AsyncFunctionLoader(AsyncBaseLoader):
    """Async function-based template loader with memory optimization.

    This loader uses a callable function to load templates, allowing for
    custom template retrieval logic such as loading from databases,
    remote services, or other dynamic sources.

    The callable receives a template name and may return:

    * ``None`` when the template does not exist,
    * the template source as a ``str``, or
    * a complete ``(source, filename, uptodate)`` tuple.

    It may be a plain function, an ``async def`` function, or any other
    callable whose result is awaitable (for example an object with an
    ``async def __call__``, or a lambda that returns a coroutine).
    """

    __slots__ = ("load_func", "is_async_func")

    def __init__(
        self,
        load_func: LoaderFunction | AsyncLoaderFunction,
        searchpath: AsyncPath | str | t.Sequence[AsyncPath | str] | None = None,
    ) -> None:
        """Initialize the function loader.

        Args:
            load_func: Function that takes a template name and returns
                template source or None if not found. Can be sync or async.
            searchpath: Path or sequence of paths for compatibility. Not read
                by this class; only forwarded to the base loader.
        """
        # The base loader takes a searchpath; default to an empty list.
        if searchpath is None:
            searchpath = []
        super().__init__(searchpath)
        self.load_func = load_func
        # Remember whether calls must be awaited directly.
        self.is_async_func = self._is_async_callable(load_func)

    @staticmethod
    def _is_async_callable(func: t.Any) -> bool:
        """Return True if calling *func* is known to produce a coroutine."""
        import inspect

        if inspect.iscoroutinefunction(func):
            return True
        # Callable instances whose ``__call__`` is ``async def``. Classes are
        # excluded: calling a class constructs an instance, not a coroutine.
        if inspect.isclass(func):
            return False
        call = getattr(func, "__call__", None)
        return call is not None and inspect.iscoroutinefunction(call)

    async def _call_load_func(self, name: str) -> t.Any:
        """Invoke ``load_func`` for *name*, awaiting its result when needed."""
        import inspect

        if self.is_async_func:
            return await self.load_func(name)  # type: ignore
        result = self.load_func(name)
        # A callable not detected as async may still hand back an awaitable
        # (e.g. a lambda wrapping a coroutine function). Await it rather than
        # rejecting it as an unexpected type and leaking the coroutine.
        if inspect.isawaitable(result):
            result = await result
        return result

    def _make_uptodate(self, name: str, source: str) -> t.Callable[[], bool]:
        """Build the ``uptodate`` callback for a source string from ``load_func``.

        The callback re-invokes the loader function and reports True only
        when it still returns the same source text.
        """
        import inspect

        def uptodate() -> bool:
            try:
                if self.is_async_func:
                    # An async loader cannot be awaited from this synchronous
                    # hook, so always report stale to force a reload.
                    return False
                current = self.load_func(name)  # type: ignore
                if inspect.isawaitable(current):
                    # Same limitation as above; close coroutines so they do
                    # not trigger "coroutine was never awaited" warnings.
                    close = getattr(current, "close", None)
                    if close is not None:
                        close()
                    return False
                if isinstance(current, tuple) and len(current) == 3:
                    # Full (source, filename, uptodate) tuple: compare source.
                    current = current[0]
                return current is not None and current == source
            except Exception:
                # Best effort: any failure while re-checking means "stale".
                return False

        return uptodate

    @internalcode
    async def get_source_async(
        self, environment: "AsyncEnvironment", name: str
    ) -> SourceType:
        """Get template source using the loader function asynchronously.

        Args:
            environment: The async environment instance (not used here).
            name: Template name to load.

        Returns:
            Tuple of (source, filename, uptodate_func). When the loader
            function returns a plain string, the template name is used as
            the filename.

        Raises:
            TemplateNotFound: If template cannot be loaded by the function.
            TypeError: If the loader function returns an unsupported type.
        """
        self._ensure_initialized()

        result = await self._call_load_func(name)

        if result is None:
            # Base-loader handler; expected to raise TemplateNotFound.
            self._handle_template_not_found(name)

        if isinstance(result, tuple) and len(result) == 3:
            # load_func returned a full SourceType tuple; pass it through.
            source, filename, uptodate = result
            return source, filename, uptodate
        if isinstance(result, str):
            # load_func returned just the source string.
            return result, name, self._make_uptodate(name, result)
        raise TypeError(f"Unexpected source type: {type(result)}")

    @internalcode
    async def list_templates_async(self) -> list[str]:
        """List templates (not supported by function loader).

        A loader function can only resolve names it is given; it offers no
        way to enumerate the templates it could load.

        Raises:
            TypeError: Always raised as function loaders cannot list templates.
        """
        raise TypeError("this loader cannot iterate over all templates")

    def update_function(self, load_func: LoaderFunction | AsyncLoaderFunction) -> None:
        """Update the loader function.

        Args:
            load_func: New loader function to use (sync or async).
        """
        self.load_func = load_func
        self.is_async_func = self._is_async_callable(load_func)