Coverage for jinja2_async_environment/compiler_modules/codegen.py: 84%
512 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 14:09 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-03 14:09 -0700
1"""Async code generator for template compilation."""
3import typing as t
5from jinja2 import nodes
6from jinja2.compiler import (
7 CodeGenerator,
8 Frame,
9 find_undeclared,
10)
12from .cache import CompilationCache
13from .dependencies import DependencyResolver
14from .frame import AsyncFrame
15from .loops import LoopCodeGenerator
16from .patterns import CompiledPatterns
if t.TYPE_CHECKING:
    pass

# Module-wide cache instance; AsyncCodeGenerator.compile_with_cache falls back
# to it when the environment does not supply a cache manager of its own.
_compilation_cache: CompilationCache = CompilationCache()
class AsyncCodeGenerator(CodeGenerator):
    """Jinja2 ``CodeGenerator`` subclass that emits async-aware template code.

    The class-level values below act as declarations and defaults; the mutable
    ones (dicts, sets, lists) are rebound to fresh objects in ``__init__``.
    """

    # Collaborators and identity of the template being compiled.
    environment: t.Any
    name: str
    filename: str
    stream: t.Any

    # Template-inheritance bookkeeping.
    extends_so_far: int
    has_known_extends: bool

    # Frame class used for the root frame created in ``generate``.
    root_frame_class: type[AsyncFrame] = AsyncFrame
    eval_ctx: t.Any = None
    is_async: bool = True

    # Compiler state defaults.
    last_identifier: int = 0
    identifiers: dict[str, t.Any] = {}
    import_aliases: dict[str, t.Any] = {}
    blocks: dict[str, t.Any] = {}
    extends_buffer: t.Any = None
    required_blocks: set[str] = set()
    has_super: bool = False
    macro_frames: list[AsyncFrame] = []

    # Names that take the short path in ``_handle_symbol_name`` when loaded.
    _COMMON_VARS = frozenset(
        {
            "context",
            "environment",
            "eval_ctx",
            "undefined",
            "item",
            "loop",
            "block",
            "value",
            "name",
            "key",
        }
    )
62 def __init__(
63 self, environment: t.Any, name: str, filename: str, defer_init: bool = False
64 ) -> None:
65 super().__init__(
66 environment, name, filename, stream=None, defer_init=defer_init
67 )
68 self.extends_so_far = 0
69 self.has_known_extends = False
70 self.has_super = False
71 self.last_identifier = 0
72 self.identifiers = {}
73 self.import_aliases = {}
74 self.blocks = {}
75 self.extends_buffer = None
76 self.required_blocks = set()
77 self.is_async = True
78 self.macro_frames = []
80 # Initialize assignment tracking stack
81 self._assign_stack: list[set[str]] = []
83 # Initialize utility classes for better code organization
84 self._dependency_resolver = DependencyResolver(self)
85 self._loop_generator = LoopCodeGenerator(self)
87 from jinja2.nodes import EvalContext
89 if self.eval_ctx is None:
90 self.eval_ctx = EvalContext(self.environment, self.name)
92 def choose_async(self, async_fmt: str = "async ", sync_fmt: str = "") -> str: # type: ignore[override]
93 return async_fmt if self.environment.enable_async else sync_fmt
95 def simple_write(self, value: str, frame: Frame) -> None: # type: ignore[override]
96 self.writeline(f"yield {value}")
98 def func_code_generator(self, frame: Frame) -> str:
99 async_frame = t.cast(AsyncFrame, frame)
100 return "async def" if async_frame.is_async else "def"
102 def func(self, name: str) -> str:
103 """Generate a function declaration for the given name."""
104 return f"def {name}"
106 def enter_frame(self, frame: Frame) -> None:
107 """Enter a new frame context."""
108 pass
110 def leave_frame(self, frame: Frame, with_python_scope: bool = False) -> None:
111 """Leave a frame context."""
112 pass
114 def return_buffer_contents(
115 self,
116 frame: Frame,
117 force_unescaped: bool = False, # noqa: ARG002
118 ) -> None:
119 _ = force_unescaped
120 if frame.buffer is not None:
121 self.writeline(f"return ''.join({frame.buffer})")
123 def visit_Name(self, node: nodes.Name, frame: Frame) -> None:
124 frame = t.cast(AsyncFrame, frame)
125 self._handle_assignment_tracking(node, frame)
126 if self._handle_special_names(node):
127 return
128 self._handle_symbol_name(node, frame)
130 def _handle_assignment_tracking(self, node: nodes.Name, frame: AsyncFrame) -> None:
131 if node.ctx == "store":
132 frame.symbols.store(node.name)
133 if frame.toplevel or frame.loop_frame or frame.block_frame:
134 if hasattr(self, "_assign_stack") and self._assign_stack:
135 self._assign_stack[-1].add(node.name)
137 def _handle_special_names(self, node: nodes.Name) -> bool:
138 if node.name in ("blocks", "debug_info"):
139 self.write(node.name)
140 return True
141 return False
143 def _handle_symbol_name(self, node: nodes.Name, frame: AsyncFrame) -> None:
144 # Fast path for common variables
145 if node.name in self._COMMON_VARS and node.ctx == "load":
146 try:
147 ref = frame.symbols.ref(node.name)
148 self.write(ref)
149 return
150 except AssertionError:
151 self.write(f"context.get({node.name!r})")
152 return
154 # Standard path for other variables
155 try:
156 ref = frame.symbols.ref(node.name)
157 if node.ctx == "load" and self._should_use_undefined_check(ref, frame):
158 self.write(
159 f"(undefined(name={node.name!r}) if {ref} is missing else {ref})"
160 )
161 else:
162 self.write(ref)
163 except AssertionError:
164 if node.ctx == "load":
165 self.write(f"context.get({node.name!r})")
166 else:
167 self.write(f"context.vars[{node.name!r}]")
169 def _should_use_undefined_check(self, ref: str, frame: AsyncFrame) -> bool:
170 from jinja2.compiler import VAR_LOAD_PARAMETER
172 load = frame.symbols.find_load(ref)
173 return not (
174 load is not None
175 and load[0] == VAR_LOAD_PARAMETER
176 and hasattr(self, "parameter_is_undeclared")
177 and not self.parameter_is_undeclared(ref)
178 )
180 def pull_dependencies(self, nodes: t.Iterable[nodes.Node]) -> None:
181 """Find all filter and test names used in the template and assign them to variables."""
182 from jinja2.compiler import DependencyFinderVisitor
184 visitor = DependencyFinderVisitor()
185 for node in nodes:
186 visitor.visit(node)
188 # Set up filter dependencies using utility class
189 for name in sorted(visitor.filters):
190 self._dependency_resolver.setup_filter_dependency(name)
192 # Set up test dependencies using utility class
193 for name in sorted(visitor.tests):
194 self._dependency_resolver.setup_test_dependency(name)
196 def generate(self, node: nodes.Template) -> str:
197 self.writeline(f"name = {self.name!r}")
198 self.writeline("blocks = {}")
199 self.writeline("debug_info = None")
201 # Use optimized cached imports for better performance
202 for import_line in CompiledPatterns.get_optimized_imports().split("\n"):
203 self.writeline(import_line)
205 self.writeline("def undefined(name=None, **_):")
206 self.indent()
207 self.writeline("return Undefined(name=name)")
208 self.outdent()
210 self.writeline("async def auto_await(value):")
211 self.indent()
212 self.writeline("if hasattr(value, '__await__'):")
213 self.indent()
214 self.writeline("return await value")
215 self.outdent()
216 self.writeline("return value")
217 self.outdent()
219 self.writeline("filters = DEFAULT_FILTERS.copy()")
220 self.writeline("filters['escape'] = escape")
221 self.writeline("async def root(context):")
222 self.indent()
223 self.writeline("parent_template = None")
224 self.writeline("environment = context.environment")
225 self.writeline("eval_ctx = context.eval_ctx")
226 self.writeline("undefined = environment.undefined")
228 from jinja2.nodes import EvalContext
230 if self.eval_ctx is None:
231 self.eval_ctx = EvalContext(self.environment, self.name)
233 frame = self.root_frame_class(eval_ctx=self.eval_ctx)
234 frame.toplevel = frame.rootlevel = True
235 frame.require_output_check = False
236 frame.buffer = None
238 for macro in node.find_all(nodes.Macro):
239 frame.symbols.store(macro.name)
241 # Pull dependencies for filters and tests
242 self.pull_dependencies(node.body)
243 self.blockvisit(node.body, frame)
244 self.outdent()
246 # Apply pattern-based optimizations to generated code
247 generated_code = self.stream.getvalue()
248 return CompiledPatterns.optimize_generated_code(generated_code)
250 def visit_Block(self, node: nodes.Block, frame: AsyncFrame) -> None:
251 """Visit a block node with proper async handling."""
252 # Cast frame to AsyncFrame for type safety
253 frame = t.cast(AsyncFrame, frame)
255 # Get block name
256 block_name = node.name
258 # Initialize block storage
259 self.writeline(f"blocks[{block_name!r}] = []")
261 # Define block function with async prefix
262 block_func_name = f"block_{block_name}"
263 async_prefix = self.choose_async()
264 self.writeline(f"{async_prefix}def {block_func_name}(context):")
265 self.indent()
267 # Empty block content placeholder
268 self.writeline("yield ''")
270 # Process block body if it exists
271 if node.body:
272 self.blockvisit(node.body, frame)
274 # End function and register block
275 self.outdent()
276 self.writeline(f"blocks[{block_name!r}].append({block_func_name})")
278 # Handle inheritance cases
279 level = 0
280 if frame.toplevel:
281 if self.has_known_extends:
282 return
283 if self.extends_so_far > 0:
284 self.writeline("if parent_template is None:")
285 self.indent()
286 level += 1
288 # Handle scoped blocks
289 if node.scoped:
290 context = self.derive_context(frame)
291 else:
292 context = self.get_context_ref()
294 # Generate block call with proper async handling
295 self.writeline(f"{async_prefix}for event in {block_func_name}({context}):")
296 self.indent()
297 self.writeline("yield event")
298 self.outdent()
300 # Close conditional blocks
301 for _ in range(level):
302 self.outdent()
304 def visit_Extends(self, node: nodes.Extends, frame: AsyncFrame) -> None:
305 """Visit an extends node with proper async handling."""
306 # Cast frame to AsyncFrame for type safety
307 frame = t.cast(AsyncFrame, frame)
309 # If output check is not required, raise CompilerExit immediately
310 if not frame.require_output_check:
311 from jinja2.compiler import CompilerExit
313 raise CompilerExit()
315 # Check if we're in a top-level scope
316 if not frame.toplevel:
317 self.fail("cannot use extend from a non top-level scope", node.lineno)
319 # Handle multiple extends
320 if self.extends_so_far > 0:
321 if not self.has_known_extends:
322 self.writeline("if parent_template is not None:")
323 self.indent()
324 self.writeline('raise TemplateRuntimeError("extended multiple times")')
325 if self.has_known_extends:
326 from jinja2.compiler import CompilerExit
328 raise CompilerExit()
329 else:
330 self.outdent()
332 # Generate async template loading code
333 self.writeline("parent_template = await environment.get_template_async(", node)
334 self.visit(node.template, frame)
335 self.write(f", {self.name!r})")
336 self.writeline("for name, parent_block in parent_template.blocks.items():")
337 self.indent()
338 self.writeline("context.blocks.setdefault(name, []).append(parent_block)")
339 self.outdent()
341 # Update inheritance tracking
342 if frame.rootlevel:
343 self.has_known_extends = True
344 self.extends_so_far += 1
346 def visit_Include(self, node: nodes.Include, frame: AsyncFrame) -> None:
347 """Visit an include node with proper async handling."""
348 # Cast frame to AsyncFrame for type safety
349 frame = t.cast(AsyncFrame, frame)
351 # Handle ignore_missing flag
352 if node.ignore_missing:
353 self.writeline("try:")
354 self.indent()
356 # Generate async template loading code
357 self.writeline("template = await environment.get_template_async(", node)
358 self.visit(node.template, frame)
359 self.write(f", {self.name!r})")
361 # Close try block for ignore_missing
362 if node.ignore_missing:
363 self.outdent()
364 self.writeline("except TemplateNotFound:")
365 self.indent()
366 self.writeline("pass")
367 self.outdent()
368 self.writeline("else:")
369 self.indent()
371 # Generate rendering code based on context flag
372 if node.with_context:
373 # With context - include local variables
374 local_context = self.dump_local_context(frame)
375 self.writeline(
376 f"async for event in template.root_render_func(template.new_context(context.get_all(), True, {local_context})):"
377 )
378 else:
379 # Without context - use default module
380 self.writeline(
381 "async for event in (await template._get_default_module_async())._body_stream:"
382 )
384 # Generate event output
385 self.indent()
386 self.simple_write("event", frame)
387 self.outdent()
389 # Close else block for ignore_missing
390 if node.ignore_missing:
391 self.outdent()
393 def visit_AsyncFor(self, node: nodes.Node, frame: AsyncFrame) -> None:
394 """Visit an async for loop node with proper async handling."""
395 # Cast frame to AsyncFrame for type safety
396 frame = t.cast(AsyncFrame, frame)
398 # Handle recursive loops (not supported)
399 if hasattr(node, "recursive") and node.recursive:
400 raise NotImplementedError("Recursive loops not supported")
402 # Get target variable name
403 target = node.target
404 item = target.name if hasattr(target, "name") else "item"
405 frame.symbols.store(item)
407 # Initialize target variable
408 self.writeline(f"{item} = None")
410 # Handle loop filter
411 loop_filter = None
412 if hasattr(node, "test") and node.test:
413 loop_filter = self.temporary_identifier()
414 self.writeline(f"{loop_filter} = ", node.test)
415 self.visit(node.test, frame)
417 # Initialize loop counter
418 loop_var = self.temporary_identifier()
419 self.writeline(f"{loop_var} = -1", node)
421 # Generate async for loop
422 self.writeline(f"async for {item} in ", node.iter)
423 self.visit(node.iter, frame)
424 self.write(":")
425 self.indent()
427 # Increment loop counter
428 self.writeline(f"{loop_var} += 1")
430 # Handle loop filter condition
431 if hasattr(node, "test") and node.test and loop_filter:
432 self.writeline(f"if {loop_filter}({item}):")
433 self.indent()
435 # Process loop body
436 if hasattr(node, "body"):
437 self.blockvisit(node.body, frame)
439 # Close filter condition
440 if hasattr(node, "test") and node.test and loop_filter:
441 self.outdent()
443 # Close main loop
444 self.outdent()
446 # Handle else clause
447 if hasattr(node, "else_") and node.else_:
448 self.writeline(f"if {loop_var} == -1:")
449 self.indent()
450 self.blockvisit(node.else_, frame)
451 self.outdent()
453 def visit_AsyncCall(self, node: nodes.Node, frame: AsyncFrame) -> None:
454 """Visit an async call node by adding await prefix."""
455 self.write("await ")
456 self.visit_Call(node, frame)
458 def visit_AsyncFilterBlock(self, node: nodes.Node, frame: AsyncFrame) -> None:
459 """Visit an async filter block node."""
460 # Cast frame to AsyncFrame for type safety
461 frame = t.cast(AsyncFrame, frame)
463 # Early return if no filter or body
464 if not hasattr(node, "filter"):
465 return
466 if not hasattr(node, "body"):
467 return
469 # Get filter node
470 filter_node = node.filter
472 # Create buffer for collecting content
473 buffer = self.temporary_identifier()
474 self.writeline(f"{buffer} = []")
476 # Create async frame for processing body
477 asyncframe = frame.copy()
478 asyncframe.buffer = buffer
479 asyncframe.toplevel = False
481 # Process the body
482 self.blockvisit(node.body, asyncframe)
484 # Generate await call for filter
485 self.writeline("await ", filter_node)
486 self.visit(filter_node, frame)
487 self.write(f"(''.join({buffer}))")
489 def visit_AsyncBlock(self, node: nodes.Node, frame: AsyncFrame) -> None:
490 """Visit an async block node."""
491 # Cast frame to AsyncFrame for type safety
492 frame = t.cast(AsyncFrame, frame)
494 # Early return if no name or body
495 if not hasattr(node, "name"):
496 return
497 if not hasattr(node, "body"):
498 return
500 # Get block name
501 block_name = node.name
503 # Initialize block storage
504 self.writeline(f"blocks[{block_name!r}] = []")
506 # Define async block function
507 block_func_name = f"block_{block_name}"
508 self.writeline(f"async def {block_func_name}(context):")
509 self.indent()
511 # Empty block content placeholder
512 self.writeline("yield ''")
514 # Process block body if it exists
515 if node.body:
516 self.blockvisit(node.body, frame)
518 # End function and register block
519 self.outdent()
520 self.writeline(f"blocks[{block_name!r}].append({block_func_name})")
522 def _import_common(
523 self, node: nodes.Import | nodes.FromImport, frame: AsyncFrame
524 ) -> None:
525 """Common import functionality with async template loading."""
526 # Cast frame to AsyncFrame for type safety
527 frame = t.cast(AsyncFrame, frame)
529 # Generate async template loading code
530 self.writeline("template = await environment.get_template_async(", node)
531 self.visit(node.template, frame)
532 self.write(f", {self.name!r})")
534 @classmethod
535 def compile_with_cache(
536 cls, environment: t.Any, source: str, name: str, filename: str
537 ) -> str:
538 """Compile template with caching support for improved performance."""
539 # Try to use environment's cache manager first, fall back to global cache
540 cache_manager = getattr(environment, "cache_manager", None)
541 if cache_manager:
542 # Use environment's cache manager
543 import hashlib
545 env_id = f"{id(environment)}:{getattr(environment, 'is_async', False)}"
546 content = f"{source}:{env_id}"
547 cache_key = hashlib.sha256(content.encode()).hexdigest()[:16]
549 # Check cache first
550 cached_code = cache_manager.get("compilation", cache_key)
551 if cached_code is not None:
552 return cached_code
554 # Compile and cache
555 generator = cls(environment, name, filename)
557 ast = environment.parse(source, name, filename)
558 compiled_code = generator.generate(ast)
560 # Store in cache
561 cache_manager.set("compilation", cache_key, compiled_code)
562 return compiled_code
563 else:
564 # Fall back to global cache for backward compatibility
565 env_id = f"{id(environment)}:{getattr(environment, 'is_async', False)}"
566 cache_key = _compilation_cache.get_cache_key(source, env_id)
568 # Check cache first
569 cached_code = _compilation_cache.get(cache_key)
570 if cached_code is not None:
571 return cached_code
573 # Compile and cache
574 generator = cls(environment, name, filename)
576 ast = environment.parse(source, name, filename)
577 compiled_code = generator.generate(ast)
579 # Store in cache
580 _compilation_cache.set(cache_key, compiled_code)
581 return compiled_code
583 def visit_For(self, node: nodes.For, frame: Frame) -> None:
584 frame = t.cast(AsyncFrame, frame)
585 if node.recursive:
586 raise NotImplementedError("Recursive loops not supported")
588 # Create frames and setup
589 loop_frame, test_frame, else_frame = self._setup_for_frames(frame)
590 extended_loop, loop_ref = self._setup_for_loop_context(node, loop_frame)
592 # Analyze nodes for variable declarations
593 self._analyze_for_nodes(node, loop_frame, else_frame)
595 # Handle loop filter
596 loop_filter_func = self._setup_for_filter(node, test_frame, loop_frame)
598 # Setup loop variables and checks
599 self._setup_for_variables(node, extended_loop, loop_ref)
601 # Generate main loop
602 iteration_indicator = self._generate_for_loop(
603 node, frame, loop_frame, extended_loop, loop_ref, loop_filter_func
604 )
606 # Handle else clause
607 self._handle_for_else(node, else_frame, iteration_indicator)
609 # Cleanup
610 self._cleanup_for_assignments(loop_frame)
612 def _setup_for_frames(
613 self, frame: AsyncFrame
614 ) -> tuple[AsyncFrame, AsyncFrame, AsyncFrame]:
615 """Setup frames for different scopes in for loop."""
616 loop_frame = frame.inner()
617 loop_frame.loop_frame = True
618 test_frame = frame.inner()
619 else_frame = frame.inner()
620 return loop_frame, test_frame, else_frame
622 def _setup_for_loop_context(
623 self, node: nodes.For, loop_frame: AsyncFrame
624 ) -> tuple[bool, str | None]:
625 """Setup extended loop context and loop reference."""
626 extended_loop = (
627 node.recursive
628 or "loop"
629 in find_undeclared(node.iter_child_nodes(only=("body",)), ("loop",))
630 or any(block.scoped for block in node.find_all(nodes.Block))
631 )
633 loop_ref = None
634 if extended_loop:
635 loop_ref = loop_frame.symbols.declare_parameter("loop")
637 return extended_loop, loop_ref
639 def _analyze_for_nodes(
640 self, node: nodes.For, loop_frame: AsyncFrame, else_frame: AsyncFrame
641 ) -> None:
642 """Analyze nodes for variable declarations."""
643 loop_frame.symbols.analyze_node(node, for_branch="body")
644 if node.else_:
645 else_frame.symbols.analyze_node(node, for_branch="else")
647 def _setup_for_filter(
648 self, node: nodes.For, test_frame: AsyncFrame, loop_frame: AsyncFrame
649 ) -> str | None:
650 """Setup loop filter if present."""
651 if not node.test:
652 return None
654 loop_filter_func = self.temporary_identifier()
655 test_frame.symbols.analyze_node(node, for_branch="test")
656 self.writeline(f"{self.func(loop_filter_func)}(filter):", node.test)
657 self.indent()
658 self.enter_frame(test_frame)
659 self.writeline(self.choose_async("async for ", "for "))
660 self.visit(node.target, loop_frame)
661 self.write(" in ")
662 self.write(self.choose_async("auto_aiter(filter)", "filter"))
663 self.write(":")
664 self.indent()
665 self.writeline("if ", node.test)
666 self.visit(node.test, test_frame)
667 self.write(":")
668 self.indent()
669 self.writeline("yield ")
670 self.visit(node.target, loop_frame)
671 self.outdent(3)
672 self.leave_frame(test_frame, with_python_scope=True)
673 return loop_filter_func
675 def _setup_for_variables(
676 self, node: nodes.For, extended_loop: bool, loop_ref: str | None
677 ) -> None:
678 """Setup loop variables and check for conflicts."""
679 if extended_loop and loop_ref:
680 self.writeline(f"{loop_ref} = missing")
682 for name in node.find_all(nodes.Name):
683 if name.ctx == "store" and name.name == "loop":
684 self.fail(
685 "Can't assign to special loop variable in for-loop target",
686 name.lineno,
687 )
689 def _generate_for_loop(
690 self,
691 node: nodes.For,
692 frame: AsyncFrame,
693 loop_frame: AsyncFrame,
694 extended_loop: bool,
695 loop_ref: str | None,
696 loop_filter_func: str | None,
697 ) -> str | None:
698 """Generate the main for loop code."""
699 # Handle else clause iteration indicator
700 iteration_indicator = None
701 if node.else_:
702 iteration_indicator = self.temporary_identifier()
703 self.writeline(f"{iteration_indicator} = 1")
705 # Generate the main loop using utility class
706 self._loop_generator.generate_async_for_header(node, node.target, loop_frame)
707 self._loop_generator.generate_loop_iterator(
708 node.iter, frame, extended_loop, loop_ref, loop_filter_func
709 )
711 self.indent()
712 self.enter_frame(loop_frame)
714 self.writeline("_loop_vars = {}")
715 self.blockvisit(node.body, loop_frame)
716 if node.else_:
717 self.writeline(f"{iteration_indicator} = 0")
718 self.outdent()
719 self.leave_frame(loop_frame, with_python_scope=not node.else_)
721 return iteration_indicator
723 def _handle_for_else(
724 self, node: nodes.For, else_frame: AsyncFrame, iteration_indicator: str | None
725 ) -> None:
726 """Handle the else clause of for loop."""
727 if not node.else_ or not iteration_indicator:
728 return
730 self.writeline(f"if {iteration_indicator}:")
731 self.indent()
732 self.enter_frame(else_frame)
733 self.blockvisit(node.else_, else_frame)
734 self.leave_frame(else_frame)
735 self.outdent()
737 def _cleanup_for_assignments(self, loop_frame: AsyncFrame) -> None:
738 """Clear assignments made in the loop from the top level."""
739 if hasattr(self, "_assign_stack") and self._assign_stack:
740 self._assign_stack[-1].difference_update(loop_frame.symbols.stores)
742 def visit_Macro(self, node: nodes.Macro, frame: Frame) -> None:
743 """Visit a macro node and generate async-aware code."""
744 frame = t.cast(AsyncFrame, frame)
745 # For now, let's just use the base class implementation without modification
746 # This ensures macros work in sync mode, and we can enhance async support later
747 super().visit_Macro(node, frame)
749 def visit_Filter(self, node: nodes.Filter, frame: Frame) -> None:
750 """Visit a filter node and generate async-aware code."""
751 frame = t.cast(AsyncFrame, frame)
753 filter_ref = self._get_filter_reference(node)
754 func = self.environment.filters.get(node.name)
756 if self.environment.is_async:
757 self.write("(await auto_await(")
759 self.write(f"{filter_ref}(")
760 self._write_filter_special_params(func)
761 self._write_filter_input(node, frame)
762 self._write_filter_arguments(node, frame)
763 self.write(")")
765 if self.environment.is_async:
766 self.write("))")
768 def _get_filter_reference(self, node: nodes.Filter) -> str:
769 """Get the filter reference from dependencies or fallback to environment."""
770 if node.name in self.filters:
771 return self.filters[node.name]
772 return f"environment.filters[{node.name!r}]"
774 def _write_filter_special_params(self, func: t.Any) -> None:
775 """Write special parameters that some filters need."""
776 from jinja2.compiler import _PassArg
778 pass_arg = None
779 if func:
780 pass_arg_type = _PassArg.from_obj(func)
781 if pass_arg_type:
782 pass_arg = {
783 _PassArg.context: "context",
784 _PassArg.eval_context: "context.eval_ctx",
785 _PassArg.environment: "environment",
786 }.get(pass_arg_type)
788 if pass_arg is not None:
789 self.write(f"{pass_arg}, ")
791 def _write_filter_input(self, node: nodes.Filter, frame: AsyncFrame) -> None:
792 """Write the filter input value."""
793 if node.node is not None:
794 self.visit(node.node, frame)
795 elif frame.buffer is not None:
796 self._write_buffer_content(frame)
798 def _write_buffer_content(self, frame: AsyncFrame) -> None:
799 """Write buffer content for filter blocks."""
800 if frame.eval_ctx.volatile:
801 self.write(
802 f"(Markup(concat({frame.buffer}))"
803 f" if context.eval_ctx.autoescape else concat({frame.buffer}))"
804 )
805 elif frame.eval_ctx.autoescape:
806 self.write(f"Markup(concat({frame.buffer}))")
807 else:
808 self.write(f"concat({frame.buffer})")
810 def _write_filter_arguments(self, node: nodes.Filter, frame: AsyncFrame) -> None:
811 """Write filter arguments and keyword arguments."""
812 for arg in node.args:
813 self.write(", ")
814 self.visit(arg, frame)
816 for kwarg in node.kwargs:
817 self.write(", ")
818 self.visit(kwarg, frame)
820 if node.dyn_args:
821 self.write(", *")
822 self.visit(node.dyn_args, frame)
824 if node.dyn_kwargs:
825 self.write(", **")
826 self.visit(node.dyn_kwargs, frame)
828 def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None:
829 """Visit an assignment node ({% set %} statements)."""
830 frame = t.cast(AsyncFrame, frame)
831 self.push_assign_tracking()
833 # Check for namespace assignments like `ns.var = value`
834 seen_refs: set[str] = set()
835 for nsref in node.find_all(nodes.NSRef):
836 if nsref.name in seen_refs:
837 continue
838 seen_refs.add(nsref.name)
839 ref = frame.symbols.ref(nsref.name)
840 self.writeline(f"if not isinstance({ref}, Namespace):")
841 self.indent()
842 self.writeline(
843 "raise TemplateRuntimeError"
844 '("cannot assign attribute on non-namespace object")'
845 )
846 self.outdent()
848 # Generate the assignment code
849 self.newline(node)
850 self.visit(node.target, frame)
851 self.write(" = ")
852 self.visit(node.node, frame)
853 self.pop_assign_tracking(frame)
855 def push_assign_tracking(self) -> None:
856 """Push a new layer for assignment tracking."""
857 self._assign_stack.append(set())
859 def pop_assign_tracking(self, frame: Frame) -> None:
860 """Pop the topmost level for assignment tracking and update context variables."""
861 frame = t.cast(AsyncFrame, frame)
862 vars_set = self._assign_stack.pop()
864 if (
865 not frame.block_frame
866 and not frame.loop_frame
867 and not frame.toplevel
868 or not vars_set
869 ):
870 return
872 public_names = [x for x in vars_set if x[:1] != "_"]
874 if len(vars_set) == 1:
875 name = next(iter(vars_set))
876 ref = frame.symbols.ref(name)
877 if frame.loop_frame:
878 self.writeline(f"_loop_vars[{name!r}] = {ref}")
879 return
880 if frame.block_frame:
881 self.writeline(f"_block_vars[{name!r}] = {ref}")
882 return
883 self.writeline(f"context.vars[{name!r}] = {ref}")
884 else:
885 if frame.loop_frame:
886 self.writeline("_loop_vars.update({")
887 elif frame.block_frame:
888 self.writeline("_block_vars.update({")
889 else:
890 self.writeline("context.vars.update({")
891 for idx, name in enumerate(sorted(vars_set)):
892 if idx:
893 self.write(", ")
894 ref = frame.symbols.ref(name)
895 self.write(f"{name!r}: {ref}")
896 self.write("})")
898 if not frame.block_frame and not frame.loop_frame and public_names:
899 if len(public_names) == 1:
900 self.writeline(f"context.exported_vars.add({public_names[0]!r})")
901 else:
902 names_str = ", ".join(map(repr, sorted(public_names)))
903 self.writeline(f"context.exported_vars.update(({names_str}))")