Coverage for jinja2_async_environment/environment.py: 92%

226 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-03 14:09 -0700

1import re 

2import typing as t 

3from contextlib import suppress 

4 

5from jinja2 import Environment, nodes 

6from jinja2.environment import Template 

7from jinja2.exceptions import TemplateNotFound, TemplatesNotFound, UndefinedError 

8from jinja2.runtime import Undefined 

9from jinja2.sandbox import SandboxedEnvironment 

10from jinja2.utils import internalcode 

11from markupsafe import escape 

12 

13from .bccache import AsyncBytecodeCache 

14from .compiler import AsyncCodeGenerator, CodeGenerator 

15 

16 

class TemplateResolver:
    """Groups the name-resolution and loading helpers used by an environment."""

    def __init__(self, environment: "AsyncEnvironment") -> None:
        # The environment is consulted for path joining and template loading.
        self.environment = environment

    def is_template_or_mock(self, obj: t.Any) -> bool:
        """Return True when ``obj`` is a ``Template`` or its type name mentions MagicMock."""
        if isinstance(obj, Template):
            return True
        return "MagicMock" in str(type(obj))

    def resolve_template_name(self, name: str, parent: str | None) -> str:
        """Return ``name`` unchanged, or joined with ``parent`` when one is given."""
        if parent is None:
            return name
        return self.environment.join_path(name, parent)

    async def load_single_template(
        self, name: str, globals: t.MutableMapping[str, t.Any] | None
    ) -> Template:
        """Load one template by delegating to the environment's async loader path."""
        loaded = await self.environment._load_template_async(name, globals)
        return loaded

    async def try_load_template(
        self, name: str, globals: t.MutableMapping[str, t.Any] | None
    ) -> tuple[Template | None, str]:
        """Attempt to load ``name``.

        Returns:
            ``(template, name)`` on success, or ``(None, name)`` when loading
            raised ``TemplateNotFound`` or ``UndefinedError``.
        """
        try:
            loaded = await self.load_single_template(name, globals)
        except (TemplateNotFound, UndefinedError):
            return None, name
        return loaded, name

48 

49 

class AsyncEnvironment(Environment):
    """Jinja2 ``Environment`` whose template lookup API is asynchronous.

    The synchronous ``get_template``, ``select_template`` and
    ``get_or_select_template`` methods raise ``NotImplementedError``; use the
    ``*_async`` counterparts instead.
    """

    code_generator_class: type[CodeGenerator] = AsyncCodeGenerator
    loader: t.Any | None = None
    bytecode_cache: AsyncBytecodeCache | None = None

    # Pre-compiled pattern used by ``_compile`` when it rewrites generated
    # source after a failed first compile attempt.
    _async_yield_pattern = re.compile(
        r"async for event in self\._async_yield_from\([^)]+\):\s*$", re.MULTILINE
    )

    # Literal substitutions applied by ``_compile`` (one ``str.replace`` per
    # entry) when the first compile attempt raises ``SyntaxError``.
    _replacement_patterns = {
        "yield from context.blocks": "pass # yield from replaced",
        "undefined(name='item') if l_0_item is missing else l_0_item": "item",
        "undefined(name='i') if l_0_i is missing else l_0_i": "i",
        "undefined(name='message') if l_0_message is missing else l_0_message": "message",
        "undefined(name='partial_var') if l_0_partial_var is missing else l_0_partial_var": "partial_var",
    }

    def __init__(
        self, *args: t.Any, cache_manager: t.Any = None, **kwargs: t.Any
    ) -> None:
        """Initialise the environment.

        Args:
            *args: Forwarded to the parent ``__init__``.
            cache_manager: Optional cache manager to use; when ``None`` the
                default from ``CacheManager.get_default()`` is used.
            **kwargs: Forwarded to the parent ``__init__``.
        """
        super().__init__(*args, **kwargs)
        self.enable_async = True
        self._template_resolver = TemplateResolver(self)

        # Set up cache manager for dependency injection
        if cache_manager is not None:
            self._cache_manager = cache_manager
        else:
            # Import here to avoid circular imports
            from .caching.manager import CacheManager

            self._cache_manager = CacheManager.get_default()

        if "escape" not in self.filters:
            self.filters["escape"] = escape

    @property
    def cache_manager(self) -> t.Any:
        """Get the cache manager for dependency injection.

        Returns:
            The cache manager instance used by this environment
        """
        return self._cache_manager

    def set_cache_manager(self, cache_manager: t.Any) -> None:
        """Set a new cache manager for this environment.

        Args:
            cache_manager: New cache manager to use
        """
        self._cache_manager = cache_manager

    def _generate(
        self,
        source: nodes.Template,
        name: str | None,
        filename: str | None = None,
        defer_init: bool = False,
    ) -> str:
        """Generate Python source for a parsed template.

        Args:
            source: Parsed template node tree.
            name: Template name; ``"<template>"`` is used when it is not a string.
            filename: File name for the generated code; defaults to the
                template name.
            defer_init: Forwarded to the code generator.

        Returns:
            The result of ``code_generator_class(...).generate(source)``.
        """
        template_name = name if isinstance(name, str) else "<template>"
        if filename is None:
            filename = template_name

        generator = self.code_generator_class(
            self, template_name, filename, defer_init=defer_init
        )
        generator.environment = self

        return generator.generate(source)  # type: ignore

    def _compile(self, source: str, filename: str) -> t.Any:
        """Compile generated source, retrying once with rewritten source.

        If the first ``compile`` raises ``SyntaxError`` and the source
        contains both ``"yield from"`` and ``"async def"``, the literal
        replacements and the regex substitution are applied and the source is
        compiled again. Any other ``SyntaxError`` is re-raised unchanged.
        """
        try:
            return compile(source, filename, "exec")
        except SyntaxError:
            if "yield from" not in source or "async def" not in source:
                raise

            # Apply each literal replacement in turn (one pass per pattern).
            for old_pattern, new_pattern in self._replacement_patterns.items():
                source = source.replace(old_pattern, new_pattern)

            # Apply regex substitution using pre-compiled pattern
            source = self._async_yield_pattern.sub(
                "async for event in self._async_yield_from(context.blocks):\n yield event",
                source,
            )

            return compile(source, filename, "exec")

    # NOTE: both type arguments are spelled out because the one-argument form
    # ``AsyncGenerator[str]`` is only accepted from Python 3.13 onwards.
    async def _async_yield_from(
        self, generator_func: t.Any
    ) -> t.AsyncGenerator[str, None]:
        """Yield every item of ``generator_func``, async or sync.

        Objects exposing ``__aiter__`` are consumed with ``async for``;
        anything else is consumed with a plain ``for``.
        """
        if hasattr(generator_func, "__aiter__"):
            async for event in generator_func:
                yield event
        else:
            for event in generator_func:
                yield event

    @internalcode
    def get_template(
        self,
        name: str | Template,
        parent: str | Template | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Unsupported synchronous entry point; always raises.

        Raises:
            NotImplementedError: Always. Use ``get_template_async``.
        """
        raise NotImplementedError("Use get_template_async instead")

    @internalcode
    async def get_template_async(
        self,
        name: str | Template | Undefined,
        parent: str | Template | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Load a template by name.

        A ``Template`` (or MagicMock-typed object) passed as ``name`` is
        returned as-is. Otherwise ``str(name)`` is resolved against
        ``str(parent)`` when ``parent`` is truthy, then loaded.
        """
        if self._template_resolver.is_template_or_mock(name):
            return t.cast(Template, name)
        resolved_name = self._template_resolver.resolve_template_name(
            str(name), str(parent) if parent else None
        )
        return await self._load_template_async(resolved_name, globals)

    @internalcode
    def select_template(
        self,
        names: t.Iterable[str | Template],
        parent: str | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Unsupported synchronous entry point; always raises.

        Raises:
            NotImplementedError: Always. Use ``select_template_async``.
        """
        raise NotImplementedError("Use select_template_async instead")

    @internalcode
    async def select_template_async(
        self,
        names: t.Iterable[str | Template],
        parent: str | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Return the first template from ``names`` that can be loaded.

        A ``Template`` (or MagicMock-typed object) encountered in ``names``
        is returned immediately.

        Raises:
            TemplatesNotFound: If ``names`` is empty, or no candidate loads
                (the error lists the resolved names that were tried).
        """
        if isinstance(names, Undefined):
            names._fail_with_undefined_error()
        if not names:
            raise TemplatesNotFound(
                message="Tried to select from an empty list of templates."
            )
        names_list = []
        for name in names:
            if self._template_resolver.is_template_or_mock(name):
                return t.cast(Template, name)
            resolved_name = self._template_resolver.resolve_template_name(
                str(name), parent
            )
            template, failed_name = await self._template_resolver.try_load_template(
                resolved_name, globals
            )
            if template is not None:
                return template
            names_list.append(failed_name)
        raise TemplatesNotFound(names_list)

    @internalcode
    def get_or_select_template(
        self,
        template_name_or_list: str | Template | t.Sequence[str | Template],
        parent: str | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Unsupported synchronous entry point; always raises.

        Raises:
            NotImplementedError: Always. Use ``get_or_select_template_async``.
        """
        raise NotImplementedError("Use get_or_select_template_async instead")

    @internalcode
    async def get_or_select_template_async(
        self,
        template_name_or_list: str | Template | t.Sequence[str | Template] | Undefined,
        parent: str | None = None,
        globals: t.MutableMapping[str, t.Any] | None = None,
    ) -> Template:
        """Dispatch to ``get_template_async`` or ``select_template_async``.

        Strings and ``Undefined`` values go to ``get_template_async``;
        ``Template`` (or MagicMock-typed) objects are returned unchanged;
        everything else is treated as a list of candidates.
        """
        if isinstance(template_name_or_list, str | Undefined):
            return await self.get_template_async(template_name_or_list, parent, globals)
        elif self._template_resolver.is_template_or_mock(template_name_or_list):
            return t.cast(Template, template_name_or_list)
        return await self.select_template_async(template_name_or_list, parent, globals)

    @internalcode
    async def _load_template_async(
        self,
        name: str | Template | t.Iterable[str | Template],
        globals: t.MutableMapping[str, t.Any] | None,
    ) -> Template:
        """Load a template from a name, a template object, or an iterable of either.

        Raises:
            TemplatesNotFound: If ``name`` is an iterable and none of its
                entries could be loaded.
        """
        if self._template_resolver.is_template_or_mock(name):
            return t.cast(Template, name)
        if isinstance(name, str):
            return await self._get_template_async(name, globals)
        names_list = []
        for template_name in name:
            if self._template_resolver.is_template_or_mock(template_name):
                return t.cast(Template, template_name)
            template, failed_name = await self._template_resolver.try_load_template(
                str(template_name), globals
            )
            if template is not None:
                return template
            names_list.append(failed_name)
        raise TemplatesNotFound(names_list)

    async def _get_template_async(
        self, name: str, globals: t.MutableMapping[str, t.Any] | None
    ) -> Template:
        """Return the template ``name`` from the cache or from the loader.

        A freshly loaded template is stored in ``self.cache`` when a cache is
        configured.

        Raises:
            TypeError: If no loader is configured on this environment.
        """
        if self.loader is None:
            raise TypeError("no loader for this environment specified")

        from weakref import ref

        # The key pairs a weak reference to the loader with the template name.
        cache_key = (ref(self.loader), name)

        template = await self._get_from_cache(cache_key, globals)
        if template is not None:
            return template

        globals_dict = self.make_globals(globals)
        template = await self._load_template_from_loader(name, globals_dict)

        if self.cache is not None:
            self.cache[cache_key] = template
        return template

    async def _get_from_cache(
        self, cache_key: t.Any, globals: t.MutableMapping[str, t.Any] | None
    ) -> Template | None:
        """Return a usable cached template for ``cache_key``, or ``None``.

        A cached template is returned (after merging ``globals`` into it)
        when ``auto_reload`` is off or the template reports itself up to
        date. ``TypeError`` / ``AttributeError`` raised anywhere during the
        lookup are suppressed and treated as a cache miss.
        """
        if self.cache is None:
            return None

        with suppress(TypeError, AttributeError):
            template = self.cache.get(cache_key)
            if template is None:
                return None

            if not self.auto_reload or await self._is_template_up_to_date(template):
                self._update_template_globals(template, globals)
                return template

        return None

    def _update_template_globals(
        self, template: Template, globals: t.MutableMapping[str, t.Any] | None
    ) -> None:
        """Merge ``globals`` into ``template.globals`` when both are usable."""
        if (
            globals
            and hasattr(template, "globals")
            and hasattr(template.globals, "update")
        ):
            template.globals.update(globals)

    def _is_mock_template(self, template: Template) -> bool:
        """Return True when the type name of ``template`` mentions MagicMock."""
        return "MagicMock" in str(type(template))

    async def _handle_mock_template_uptodate(self, template: Template) -> bool:
        """Evaluate ``is_up_to_date`` on a mock template.

        A missing attribute counts as up to date. A non-callable attribute is
        coerced to ``bool``. A callable is invoked and its result awaited
        when it exposes ``__await__``.
        """
        if not hasattr(template, "is_up_to_date"):
            return True
        up_to_date_attr = template.is_up_to_date
        if not callable(up_to_date_attr):
            return bool(up_to_date_attr)
        result = up_to_date_attr()
        if hasattr(result, "__await__"):
            return await result
        return result

    def _has_uptodate_attribute(self, template: Template) -> bool:
        """Report whether ``template`` provides ``is_up_to_date``.

        When the instance has a ``__dict__`` that lacks the name, the classes
        in the type's MRO are searched instead. In every other situation
        (including a suppressed ``AttributeError`` / ``TypeError``) the
        answer is ``True``.
        """
        with suppress(AttributeError, TypeError):
            if (
                hasattr(template, "__dict__")
                and "is_up_to_date" not in template.__dict__
            ):
                return any(
                    hasattr(cls, "__dict__") and "is_up_to_date" in cls.__dict__
                    for cls in type(template).__mro__
                )
        return True

    def _get_uptodate_attribute(self, template: Template) -> t.Any:
        """Return ``template.is_up_to_date``, or ``None`` if reading it fails."""
        try:
            return getattr(template, "is_up_to_date", None)
        except Exception:
            # Best effort: any failure while reading the attribute is treated
            # as "no attribute".
            return None

    async def _evaluate_uptodate_attribute(self, uptodate_attr: t.Any) -> bool:
        """Turn an ``is_up_to_date`` value into a boolean.

        Accepts a coroutine object, a coroutine function, a plain callable
        (whose result may itself be a coroutine) or a plain value. Any
        exception raised while awaiting or calling is treated as "up to
        date" and yields ``True``.
        """
        import inspect

        if inspect.iscoroutine(uptodate_attr):
            try:
                result = await uptodate_attr
                return bool(result)
            except Exception:
                return True
        if inspect.iscoroutinefunction(uptodate_attr):
            try:
                result = await uptodate_attr()
                return bool(result)
            except Exception:
                return True
        if callable(uptodate_attr):
            try:
                result = uptodate_attr()
                if inspect.iscoroutine(result):
                    return bool(await result)
                return bool(result)
            except Exception:
                return True

        return bool(uptodate_attr)

    async def _is_template_up_to_date(self, template: Template) -> bool:
        """Decide whether a cached template may be reused.

        Mock templates use the dedicated mock path. Templates without a
        usable ``is_up_to_date`` attribute are considered up to date.
        """
        if self._is_mock_template(template):
            return await self._handle_mock_template_uptodate(template)
        if not self._has_uptodate_attribute(template):
            return True
        uptodate_attr = self._get_uptodate_attribute(template)
        if uptodate_attr is None:
            return True

        return await self._evaluate_uptodate_attribute(uptodate_attr)

    async def _load_template_from_loader(
        self, name: str, globals_dict: t.MutableMapping[str, t.Any]
    ) -> Template:
        """Load ``name`` via ``loader.load_async`` when present, else ``loader.load``."""
        if hasattr(self.loader, "load_async"):
            return await self.loader.load_async(self, name, globals_dict)
        return self.loader.load(self, name, globals_dict)

388 

389 

class AsyncSandboxedEnvironment(SandboxedEnvironment, AsyncEnvironment):
    """Sandboxed environment combined with the async template-loading API."""

    code_generator_class: type[CodeGenerator] = AsyncCodeGenerator

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Initialise via ``SandboxedEnvironment.__init__`` with async enabled.

        ``enable_async`` defaults to ``True`` unless the caller supplies it.
        """
        kwargs.setdefault("enable_async", True)
        SandboxedEnvironment.__init__(self, *args, **kwargs)
        self.enable_async = True
        # Register the markupsafe ``escape`` filter only when none is present.
        escape_missing = "escape" not in self.filters
        if escape_missing:
            self.filters["escape"] = escape

    def compile_expression(self, source: str, undefined_to_none: bool = True) -> t.Any:
        """Compile an expression using the ``SandboxedEnvironment`` implementation."""
        compiled = SandboxedEnvironment.compile_expression(
            self, source, undefined_to_none
        )
        return compiled