Coverage for netrun / rbac / testing.py: 24%
294 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-18 22:20 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-18 22:20 +0000
"""
Tenant Isolation Contract Testing Utilities for Multi-Tenant Applications.

Provides pgTAP-style assertions for verifying multi-tenant data isolation.
Use in integration tests to demonstrate that cross-tenant data access is blocked.

This module is CRITICAL for security compliance (SOC2, ISO27001, NIST).

Features:
- Query analysis to detect missing tenant filters
- Test context management for multi-tenant scenarios
- Background task tenant context preservation
- Escape path detection and prevention
- CI/CD integration utilities

Usage:
    from netrun.rbac.testing import (
        TenantIsolationError,
        assert_tenant_isolation,
        TenantTestContext,
        BackgroundTaskTenantContext,
        TenantEscapePathScanner,
    )

    # Assert query includes tenant filter
    query = select(Item).where(Item.tenant_id == tenant_id)
    await assert_tenant_isolation(query)

    # Test cross-tenant isolation (relies on PostgreSQL RLS policies that
    # read the tenant session variable set by TenantTestContext)
    async with TenantTestContext(session) as ctx:
        # Create data in tenant A
        item = Item(name="Secret", tenant_id=ctx.tenant_a_id)
        session.add(item)
        await session.flush()  # write the row while tenant A is active

        # Switch to tenant B and verify isolation
        await ctx.switch_to_tenant_b()
        items = await session.execute(select(Item))
        assert len(items.scalars().all()) == 0

Note:
    The query checks in this module are static, text-based heuristics.
    They complement, but do not replace, database-enforced Row-Level Security.

Security Level: CRITICAL
Compliance: SOC2 CC6.1, ISO27001 A.9.4, NIST AC-4
"""
44from __future__ import annotations
46import asyncio
47import contextvars
48import functools
49import logging
50import re
51import warnings
52from contextlib import asynccontextmanager
53from dataclasses import dataclass, field
54from enum import Enum
55from typing import (
56 Any,
57 AsyncGenerator,
58 Callable,
59 Coroutine,
60 Dict,
61 List,
62 Optional,
63 Pattern,
64 Protocol,
65 Set,
66 Tuple,
67 TypeVar,
68 Union,
69)
70from uuid import uuid4
72from sqlalchemy import Select, text
73from sqlalchemy.ext.asyncio import AsyncSession
74from sqlalchemy.sql import ClauseElement
76from .exceptions import TenantIsolationError
logger = logging.getLogger(__name__)

# Type variables
# Generic placeholder type variable (not referenced elsewhere in this module).
T = TypeVar("T")
# Any coroutine function; used to type the decorators and BackgroundTaskTenantContext.run.
AsyncFunc = TypeVar("AsyncFunc", bound=Callable[..., Coroutine[Any, Any, Any]])

# Context variable for current tenant in async context.
# Written by TenantTestContext (tenant switches / clearing) and by
# BackgroundTaskTenantContext.run; None means "no tenant set".
# NOTE(review): nothing in this module reads it — presumably consumed by
# application code; verify before relying on it.
_current_tenant_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_tenant_id", default=None
)
class EscapePathSeverity(str, Enum):
    """
    Severity levels for tenant escape path findings.

    Members, from most to least severe:
        CRITICAL -- immediate data leak risk
        HIGH     -- likely data leak under specific conditions
        MEDIUM   -- potential leak in edge cases
        LOW      -- best-practice violation, no immediate risk
        INFO     -- informational finding only
    """

    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"
@dataclass
class EscapePathFinding:
    """
    One potential tenant-isolation escape path reported by a scan.

    Attributes:
        severity: How serious the finding is.
        category: Kind of escape path (e.g. raw_sql, pagination, background).
        description: Human-readable explanation of the issue.
        location: Where it was found (file:line or a caller-supplied context).
        remediation: Suggested fix.
        compliance_impact: Affected compliance controls (empty by default).
    """

    severity: EscapePathSeverity
    category: str
    description: str
    location: str
    remediation: str
    compliance_impact: List[str] = field(default_factory=list)

    def __str__(self) -> str:
        level = self.severity.value.upper()
        parts = (
            f"[{level}] {self.category}: {self.description}",
            f" Location: {self.location}",
            f" Remediation: {self.remediation}",
        )
        return "\n".join(parts)
class SessionFactoryProtocol(Protocol):
    """
    Structural type for async session factories.

    Satisfied by any callable that can be awaited as ``await factory(tenant_id)``
    and returns an ``AsyncSession`` scoped to that tenant.
    """

    async def __call__(self, tenant_id: str) -> AsyncSession:
        """Return a new session scoped to ``tenant_id``."""
        ...
137# =============================================================================
138# Core Assertion Functions
139# =============================================================================
async def assert_tenant_isolation(
    query: Union[Select, ClauseElement, str],
    tenant_column: str = "tenant_id",
    session: Optional[AsyncSession] = None,
    strict: bool = True,
    allowed_patterns: Optional[List[Pattern[str]]] = None,
) -> None:
    """
    Assert that a SQLAlchemy query includes tenant filtering.

    Raises TenantIsolationError if the query text shows no filter on the
    tenant column.

    Args:
        query: SQLAlchemy Select statement, ClauseElement, or raw SQL string.
            Non-string queries are compiled with literal binds (falling back
            to ``str(query)`` if that fails) before analysis.
        tenant_column: Name of the tenant column (default: "tenant_id").
            Matched case-insensitively, literally (regex metacharacters are
            escaped) and only as a whole identifier, so a different column
            such as ``parent_tenant_id`` does not count as a tenant filter.
        session: Accepted for API compatibility; not used by the analysis.
        strict: Accepted for API compatibility; the analysis currently
            behaves the same for True and False.
        allowed_patterns: Compiled regexes searched (case-sensitively) in the
            compiled query text; if any matches, the query is accepted without
            a tenant filter (e.g. system tables, public lookup tables).

    Raises:
        TenantIsolationError: If query is missing tenant filter

    Example:
        # This FAILS - no tenant filter
        query = select(Item).where(Item.status == "active")
        await assert_tenant_isolation(query)  # Raises TenantIsolationError!

        # This PASSES
        query = select(Item).where(Item.tenant_id == tenant_id, Item.status == "active")
        await assert_tenant_isolation(query)  # OK

        # Allow certain system queries
        allowed = [re.compile(r"system_config"), re.compile(r"lookup_")]
        query = select(SystemConfig)  # No tenant_id column
        await assert_tenant_isolation(query, allowed_patterns=allowed)  # OK

    Security Note:
        This function performs STATIC analysis of the query string.
        It cannot detect all possible bypass scenarios (e.g. an ``OR`` branch
        that widens the filter, or dynamic SQL injection).
        Use in combination with RLS policies for defense-in-depth.
    """
    # Convert query to string for analysis
    if isinstance(query, str):
        compiled = query
    else:
        try:
            compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
        except Exception:
            # Some bind values cannot be rendered as literals; the parameterized
            # form still contains the column names this check looks for.
            compiled = str(query)

    # Explicitly allowed queries (system/lookup tables) skip the tenant check.
    if allowed_patterns:
        for pattern in allowed_patterns:
            if pattern.search(compiled):
                logger.debug("Query allowed by pattern %s: %s...", pattern.pattern, compiled[:100])
                return

    compiled_lower = compiled.lower()

    # The column must appear as a whole identifier: not preceded or followed by
    # another identifier character, so "parent_tenant_id" or "tenant_id_old"
    # do not satisfy "tenant_id". An optional closing quote/backtick/bracket
    # admits quoted identifiers such as "tenant_id" = ...
    column = rf"(?<![\w$]){re.escape(tenant_column.lower())}(?![\w$])[\"`\]]?"
    tenant_filter_patterns = (
        # Equality in WHERE or JOIN ... ON, against a literal or a bind parameter
        rf"{column}\s*=",
        # IN clause with tenant
        rf"{column}\s+in\s*\(",
        # Any reference inside a WHERE clause (e.g. subquery correlation)
        rf"where\s+.*{column}",
    )

    has_tenant_filter = any(
        re.search(pattern, compiled_lower) for pattern in tenant_filter_patterns
    )

    if not has_tenant_filter:
        # Truncate long queries for error message
        query_preview = compiled[:500] + "..." if len(compiled) > 500 else compiled

        raise TenantIsolationError(
            f"Query missing tenant isolation! Expected '{tenant_column}' filter.\n"
            f"Query: {query_preview}\n\n"
            f"REMEDIATION: Add .where({tenant_column} == tenant_id) to your query.\n"
            f"COMPLIANCE IMPACT: SOC2 CC6.1, ISO27001 A.9.4.1, NIST AC-4"
        )

    logger.debug("Tenant isolation verified for query: %s...", compiled[:100])
def assert_tenant_isolation_sync(
    query: Union[Select, ClauseElement, str],
    tenant_column: str = "tenant_id",
    strict: bool = True,
    allowed_patterns: Optional[List[Pattern[str]]] = None,
) -> None:
    """
    Synchronous version of assert_tenant_isolation.

    For use in non-async contexts or pytest fixtures. Unlike starting a new
    event loop, this also works when called from synchronous code that is
    itself running inside an active event loop (e.g. a helper called from an
    async test), where ``loop.run_until_complete`` would raise RuntimeError.

    Args:
        query: SQLAlchemy Select statement, ClauseElement, or raw SQL string.
        tenant_column: Name of the tenant column (default: "tenant_id").
        strict: Forwarded to assert_tenant_isolation.
        allowed_patterns: Regexes for queries allowed without a tenant filter.

    Raises:
        TenantIsolationError: If query is missing tenant filter.
    """

    def _make_check():
        return assert_tenant_isolation(
            query=query,
            tenant_column=tenant_column,
            strict=strict,
            allowed_patterns=allowed_patterns,
        )

    check = _make_check()
    try:
        # assert_tenant_isolation, as currently written, contains no `await`,
        # so a single send() runs it to completion without any event loop.
        # Exceptions it raises (e.g. TenantIsolationError) propagate from here.
        check.send(None)
    except StopIteration:
        return

    # The coroutine suspended, i.e. it awaited real async work: discard it and
    # fall back to running a fresh copy on a private event loop.
    check.close()
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(_make_check())
    finally:
        loop.close()
264# =============================================================================
265# Test Context Management
266# =============================================================================
class TenantTestContext:
    """
    Context manager for testing tenant isolation.

    Creates two test tenant IDs and sets the PostgreSQL session variable read
    by RLS policies, so a test can check that data does not leak between them.

    The tenant is applied with ``SELECT set_config(<session_variable>,
    <tenant_id>, true)``, the function form of ``SET LOCAL``. ``SET`` itself
    cannot take bind parameters, which drivers such as asyncpg require. As with
    ``SET LOCAL``, the value only lasts until the current transaction ends, so
    after ``session.commit()`` call a ``switch_to_*`` method again before
    relying on RLS.

    Example:
        async with TenantTestContext(session) as ctx:
            # Create data in tenant A (context starts here)
            item_a = Item(name="Secret", tenant_id=ctx.tenant_a_id)
            session.add(item_a)
            await session.commit()

            # Switch to tenant B and try to read
            await ctx.switch_to_tenant_b()

            # This query should return empty due to RLS
            result = await session.execute(select(Item))
            items = result.scalars().all()
            assert item_a not in items, "CRITICAL: Tenant B can see Tenant A's data!"

    Attributes:
        tenant_a_id: ID for test tenant A (default "test-tenant-a-<8 hex chars>")
        tenant_b_id: ID for test tenant B (default "test-tenant-b-<8 hex chars>")
        current_tenant: Currently active tenant ID ("" after clear_tenant_context)
        session: Database session with RLS context
    """

    def __init__(
        self,
        session: AsyncSession,
        tenant_a_id: Optional[str] = None,
        tenant_b_id: Optional[str] = None,
        session_variable: str = "app.current_tenant_id",
        user_session_variable: str = "app.current_user_id",
        auto_cleanup: bool = True,
    ):
        """
        Initialize tenant test context.

        Args:
            session: SQLAlchemy AsyncSession to use for testing
            tenant_a_id: Override tenant A ID (default: auto-generated)
            tenant_b_id: Override tenant B ID (default: auto-generated)
            session_variable: PostgreSQL session variable for tenant ID
            user_session_variable: PostgreSQL session variable for user ID
                (only RESET on exit; never set by this class)
            auto_cleanup: Whether to reset context on exit
        """
        self.session = session
        self.tenant_a_id = tenant_a_id or f"test-tenant-a-{uuid4().hex[:8]}"
        self.tenant_b_id = tenant_b_id or f"test-tenant-b-{uuid4().hex[:8]}"
        self.current_tenant = self.tenant_a_id
        self.session_variable = session_variable
        self.user_session_variable = user_session_variable
        self.auto_cleanup = auto_cleanup
        # Database-side tenant value seen on entry (restored on exit if truthy).
        self._original_tenant: Optional[str] = None
        # In-process ContextVar value seen on entry (restored on exit).
        self._original_context_tenant: Optional[str] = None
        self._context_history: List[Tuple[str, str]] = []  # (action, tenant_id)

    async def __aenter__(self) -> "TenantTestContext":
        """Enter context: remember prior tenant state, then activate tenant A."""
        self._original_context_tenant = _current_tenant_id.get()

        # Store original tenant context if any (best effort)
        try:
            self._original_tenant = await self._read_tenant_setting()
        except Exception:
            self._original_tenant = None

        # Set RLS context for tenant A
        await self._set_tenant(self.tenant_a_id)
        self._context_history.append(("enter", self.tenant_a_id))

        logger.info("TenantTestContext initialized. Tenant A: %s", self.tenant_a_id)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Exit context and optionally reset RLS and the tenant ContextVar."""
        if self.auto_cleanup:
            try:
                if self._original_tenant:
                    await self._set_tenant(self._original_tenant)
                else:
                    await self.session.execute(text(f"RESET {self.session_variable}"))
                    await self.session.execute(text(f"RESET {self.user_session_variable}"))
            except Exception:
                if exc_type is None:
                    raise
                # The body already failed (e.g. it left the transaction aborted);
                # don't let a cleanup error replace the exception the test reports.
                logger.warning(
                    "TenantTestContext cleanup failed after an earlier error",
                    exc_info=True,
                )
            finally:
                # Undo the ContextVar writes made by _set_tenant so the test
                # tenant does not leak into code running after this block.
                _current_tenant_id.set(self._original_context_tenant)

        self._context_history.append(("exit", self.current_tenant))
        logger.info("TenantTestContext exited. History: %d actions", len(self._context_history))

    async def _read_tenant_setting(self) -> Optional[str]:
        """Return the raw tenant session variable (NULL/None if it was never set)."""
        result = await self.session.execute(
            text("SELECT current_setting(:name, true)"),
            {"name": self.session_variable},
        )
        return result.scalar()

    async def _set_tenant(self, tenant_id: str) -> None:
        """Set tenant context via PostgreSQL session variable (transaction-local)."""
        # set_config(name, value, is_local=true) == SET LOCAL, but both the
        # variable name and the value can be passed as bind parameters.
        await self.session.execute(
            text("SELECT set_config(:name, :tenant_id, true)"),
            {"name": self.session_variable, "tenant_id": tenant_id},
        )
        self.current_tenant = tenant_id
        _current_tenant_id.set(tenant_id)

    async def switch_to_tenant_a(self) -> None:
        """Switch to tenant A context."""
        await self._set_tenant(self.tenant_a_id)
        self._context_history.append(("switch_a", self.tenant_a_id))
        logger.debug("Switched to tenant A: %s", self.tenant_a_id)

    async def switch_to_tenant_b(self) -> None:
        """Switch to tenant B context."""
        await self._set_tenant(self.tenant_b_id)
        self._context_history.append(("switch_b", self.tenant_b_id))
        logger.debug("Switched to tenant B: %s", self.tenant_b_id)

    async def switch_to_tenant(self, tenant_id: str) -> None:
        """Switch to arbitrary tenant context (for advanced testing)."""
        await self._set_tenant(tenant_id)
        self._context_history.append(("switch_custom", tenant_id))
        logger.debug("Switched to custom tenant: %s", tenant_id)

    async def clear_tenant_context(self) -> None:
        """
        Clear tenant context (simulate superuser/admin access).

        WARNING: Use only for testing admin bypass scenarios.
        """
        await self.session.execute(text(f"RESET {self.session_variable}"))
        self.current_tenant = ""
        _current_tenant_id.set(None)
        self._context_history.append(("clear", ""))
        logger.warning("Tenant context cleared - operating without RLS filtering")

    async def get_current_tenant(self) -> Optional[str]:
        """Get the currently set tenant ID from PostgreSQL (None if unset or empty)."""
        return await self._read_tenant_setting() or None

    def get_context_history(self) -> List[Tuple[str, str]]:
        """Get a copy of the history of context switches for debugging."""
        return self._context_history.copy()
@asynccontextmanager
async def tenant_test_context(
    session: AsyncSession,
    **kwargs: Any,
) -> AsyncGenerator[TenantTestContext, None]:
    """
    Function-style wrapper around TenantTestContext.

    All keyword arguments are forwarded to the TenantTestContext constructor;
    the yielded object is the entered context.

    Example:
        async with tenant_test_context(session) as ctx:
            # Create data in tenant A
            ...
            # Switch and verify isolation
            await ctx.switch_to_tenant_b()
            ...
    """
    async with TenantTestContext(session, **kwargs) as entered:
        yield entered
439# =============================================================================
440# Background Task Context Preservation
441# =============================================================================
class BackgroundTaskTenantContext:
    """
    Wrapper for background tasks that preserves tenant context.

    CRITICAL: Background tasks lose request context by default!
    FastAPI's BackgroundTasks runs after the response is sent,
    meaning the original request's tenant context is lost.

    Example:
        # WRONG - loses tenant context
        background_tasks.add_task(process_items)

        # RIGHT - preserves tenant context
        background_tasks.add_task(
            BackgroundTaskTenantContext(tenant_id, session_factory).run(process_items)
        )

    For Celery/Redis Queue integration:
        # In task definition
        @celery_app.task
        def process_items_task(tenant_id: str, item_ids: list):
            async def inner():
                ctx = BackgroundTaskTenantContext(tenant_id, session_factory)
                async with ctx.get_session() as session:
                    await process_items(session, item_ids)
            asyncio.run(inner())

    Security Note:
        Always pass tenant_id explicitly - never rely on context inheritance.
        This ensures audit trails are maintained for background operations.
    """

    def __init__(
        self,
        tenant_id: str,
        session_factory: Optional[SessionFactoryProtocol] = None,
        user_id: Optional[str] = None,
        correlation_id: Optional[str] = None,
    ):
        """
        Initialize background task tenant context.

        Args:
            tenant_id: Tenant ID to scope the background task
            session_factory: Async function that creates a session with tenant context
            user_id: User ID for audit logging (optional; stored, not used here)
            correlation_id: Request correlation ID for tracing (optional;
                a random hex ID is generated when omitted)
        """
        self.tenant_id = tenant_id
        self.session_factory = session_factory
        self.user_id = user_id
        self.correlation_id = correlation_id or uuid4().hex

    def run(
        self,
        func: AsyncFunc,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[[], Coroutine[Any, Any, Any]]:
        """
        Wrap an async function to run with tenant context.

        When awaited, the returned zero-argument coroutine function:
          * sets the module's tenant ContextVar to ``tenant_id`` and restores
            the previous value afterwards, so later tasks that share the same
            context (e.g. sequential BackgroundTasks) do not inherit it;
          * if ``session_factory`` is configured, creates a session for the
            tenant, passes it to ``func`` as the ``session`` keyword argument
            and closes it once ``func`` finishes (successfully or not).

        Args:
            func: Async function to wrap
            *args: Arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function (must not
                include ``session`` when a session_factory is configured)

        Returns:
            Wrapped async function that can be added to BackgroundTasks
        """

        @functools.wraps(func)
        async def wrapped() -> Any:
            token = _current_tenant_id.set(self.tenant_id)

            logger.info(
                "Background task starting: %s [tenant=%s, correlation=%s]",
                func.__name__,
                self.tenant_id,
                self.correlation_id,
            )

            session = None
            try:
                if self.session_factory:
                    # Use provided session factory; this wrapper owns the session.
                    session = await self.session_factory(self.tenant_id)
                    return await func(*args, session=session, **kwargs)
                # Session should be provided in kwargs
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(
                    "Background task failed: %s [tenant=%s, correlation=%s]: %s",
                    func.__name__,
                    self.tenant_id,
                    self.correlation_id,
                    e,
                )
                raise
            finally:
                try:
                    if session is not None:
                        await session.close()
                finally:
                    _current_tenant_id.reset(token)
                    logger.info(
                        "Background task completed: %s [tenant=%s, correlation=%s]",
                        func.__name__,
                        self.tenant_id,
                        self.correlation_id,
                    )

        return wrapped

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """
        Get a session with tenant context set; the session is closed on exit.

        Requires session_factory to be configured.

        Example:
            async with BackgroundTaskTenantContext(tenant_id, factory).get_session() as session:
                result = await session.execute(select(Item))

        Raises:
            ValueError: If no session_factory was provided.
        """
        if not self.session_factory:
            raise ValueError(
                "session_factory must be provided to use get_session(). "
                "Either pass session_factory in constructor or provide session manually."
            )

        session = await self.session_factory(self.tenant_id)
        try:
            yield session
        finally:
            await session.close()
def preserve_tenant_context(
    tenant_id: str,
    session_factory: Optional[SessionFactoryProtocol] = None,
) -> Callable[[AsyncFunc], Callable[..., Coroutine[Any, Any, Any]]]:
    """
    Decorator to preserve tenant context in background tasks.

    Each call of the decorated coroutine function is routed through a fresh
    ``BackgroundTaskTenantContext(tenant_id, session_factory)``. When a
    session_factory is given, the session is passed as the ``session``
    keyword argument, so the decorated function must accept ``session`` as a
    keyword and callers must not also supply it positionally.

    Example:
        @preserve_tenant_context(tenant_id, session_factory)
        async def process_items(item_ids, *, session):
            ...

        # Add to background tasks
        background_tasks.add_task(process_items, item_ids)
    """

    def decorate(func: AsyncFunc) -> Callable[..., Coroutine[Any, Any, Any]]:
        @functools.wraps(func)
        async def call_with_tenant(*args: Any, **kwargs: Any) -> Any:
            tenant_ctx = BackgroundTaskTenantContext(tenant_id, session_factory)
            runner = tenant_ctx.run(func, *args, **kwargs)
            return await runner()

        return call_with_tenant

    return decorate
597# =============================================================================
598# Escape Path Detection and Scanning
599# =============================================================================
class TenantEscapePathScanner:
    """
    Scans code and queries for potential tenant isolation escape paths.

    Use in CI/CD pipelines to detect security issues before deployment.
    Detection is line-based regex matching: a heuristic that can report false
    positives and miss multi-line constructs.

    Detects:
    - Queries without tenant filters
    - Raw SQL that bypasses ORM
    - Background tasks without context preservation
    - Pagination without tenant scope
    - Aggregations that leak tenant boundaries
    - JOIN queries with missing tenant conditions
    - Subqueries without tenant correlation

    Example:
        scanner = TenantEscapePathScanner()

        # Scan a query
        findings = scanner.scan_query(query_string)

        # Scan a Python file
        findings = scanner.scan_file("/path/to/repo.py")

        # Scan entire directory
        findings = scanner.scan_directory("/path/to/repo")

        # Fail CI if critical findings
        critical = [f for f in findings if f.severity == EscapePathSeverity.CRITICAL]
        if critical:
            sys.exit(1)
    """

    # Patterns that indicate potential escape paths:
    # (regex, severity, category, description)
    DANGEROUS_PATTERNS: List[Tuple[Pattern[str], EscapePathSeverity, str, str]] = [
        # Raw SQL execution
        (
            re.compile(r"execute\s*\(\s*['\"]SELECT", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "raw_sql",
            "Raw SELECT query detected. Ensure tenant filter is included.",
        ),
        (
            re.compile(r"execute\s*\(\s*['\"]UPDATE", re.IGNORECASE),
            EscapePathSeverity.CRITICAL,
            "raw_sql",
            "Raw UPDATE query detected. Could modify other tenants' data.",
        ),
        (
            re.compile(r"execute\s*\(\s*['\"]DELETE", re.IGNORECASE),
            EscapePathSeverity.CRITICAL,
            "raw_sql",
            "Raw DELETE query detected. Could delete other tenants' data.",
        ),
        # Pagination without tenant filter
        (
            re.compile(r"\.offset\s*\([^)]+\)\.limit\s*\([^)]+\)(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "pagination",
            "Pagination detected without visible tenant filter.",
        ),
        # Aggregation across tenants
        (
            re.compile(r"func\.(count|sum|avg|max|min)\s*\((?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "aggregation",
            "Aggregation function without tenant filter could leak cross-tenant metrics.",
        ),
        # Background task without context.
        # NOTE(review): with IGNORECASE, [^B] also excludes "b", so add_task
        # calls whose arguments contain a "b" are never flagged.
        (
            re.compile(r"add_task\s*\([^B]*\)(?!.*TenantContext)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "background",
            "BackgroundTask may lose tenant context. Use BackgroundTaskTenantContext.",
        ),
        # Session without RLS
        (
            re.compile(r"AsyncSession\s*\(\s*\)(?!.*set_tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "session",
            "Session created without explicit RLS context setup.",
        ),
        # UNION queries (potential cross-tenant join)
        (
            re.compile(r"\bUNION\b(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "union",
            "UNION query detected. Ensure all subqueries have tenant filters.",
        ),
        # Subquery without correlation
        (
            re.compile(r"\(\s*SELECT[^)]+FROM[^)]+\)(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "subquery",
            "Subquery without visible tenant correlation.",
        ),
    ]

    # Patterns that indicate SAFE tenant handling
    SAFE_PATTERNS: List[Pattern[str]] = [
        re.compile(r"tenant_id\s*=", re.IGNORECASE),
        re.compile(r"\.filter\s*\([^)]*tenant", re.IGNORECASE),
        re.compile(r"\.where\s*\([^)]*tenant", re.IGNORECASE),
        re.compile(r"current_setting\s*\(['\"]app\.current_tenant", re.IGNORECASE),
        re.compile(r"set_tenant_context", re.IGNORECASE),
        re.compile(r"BackgroundTaskTenantContext", re.IGNORECASE),
        re.compile(r"TenantTestContext", re.IGNORECASE),
    ]

    def __init__(
        self,
        tenant_column: str = "tenant_id",
        custom_dangerous_patterns: Optional[List[Tuple[Pattern[str], EscapePathSeverity, str, str]]] = None,
        custom_safe_patterns: Optional[List[Pattern[str]]] = None,
        ignore_patterns: Optional[List[Pattern[str]]] = None,
    ):
        """
        Initialize escape path scanner.

        Args:
            tenant_column: Name of tenant column to check for. A line
                containing ``<tenant_column> =`` is treated as safe.
            custom_dangerous_patterns: Additional patterns to detect
            custom_safe_patterns: Additional patterns that indicate safe handling
            ignore_patterns: Regexes searched in file paths; matching files are
                skipped by scan_file (e.g. test files)
        """
        self.tenant_column = tenant_column
        self.dangerous_patterns = list(self.DANGEROUS_PATTERNS)
        self.safe_patterns = list(self.SAFE_PATTERNS)
        self.ignore_patterns = ignore_patterns or []

        # The built-in safe patterns only recognize the default "tenant_id"
        # column; make a custom tenant column count as a tenant filter too.
        if tenant_column.lower() != "tenant_id":
            self.safe_patterns.append(
                re.compile(rf"(?<![\w$]){re.escape(tenant_column)}\s*=", re.IGNORECASE)
            )

        if custom_dangerous_patterns:
            self.dangerous_patterns.extend(custom_dangerous_patterns)
        if custom_safe_patterns:
            self.safe_patterns.extend(custom_safe_patterns)

    def scan_query(self, query: str, context: str = "unknown") -> List[EscapePathFinding]:
        """
        Scan a query string for potential escape paths.

        A query matching any safe pattern yields no findings at all; otherwise
        every matching dangerous pattern yields one finding.

        Args:
            query: SQL query or SQLAlchemy query string
            context: Context information (e.g., file:line)

        Returns:
            List of findings
        """
        # Check if query is explicitly safe
        if any(safe_pattern.search(query) for safe_pattern in self.safe_patterns):
            return []  # Query appears to have proper tenant handling

        findings: List[EscapePathFinding] = []
        for pattern, severity, category, description in self.dangerous_patterns:
            if pattern.search(query):
                findings.append(
                    EscapePathFinding(
                        severity=severity,
                        category=category,
                        description=description,
                        location=context,
                        remediation=f"Add {self.tenant_column} filter to query",
                        compliance_impact=["SOC2 CC6.1", "ISO27001 A.9.4", "NIST AC-4"],
                    )
                )

        return findings

    def scan_file(self, file_path: str) -> List[EscapePathFinding]:
        """
        Scan a Python file line by line for potential escape paths.

        Lines starting with ``#``, ``\"\"\"`` or ``'''`` are skipped; other
        lines inside multi-line strings are still scanned.

        Args:
            file_path: Path to Python file

        Returns:
            List of findings with line numbers (empty if the file is ignored
            or cannot be read)
        """
        findings: List[EscapePathFinding] = []

        # Skip ignored files before touching the filesystem
        if any(ignore_pattern.search(file_path) for ignore_pattern in self.ignore_patterns):
            return findings

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            logger.warning("Could not read file %s: %s", file_path, e)
            return findings

        for line_num, line in enumerate(content.split("\n"), 1):
            # Skip comments
            stripped = line.strip()
            if stripped.startswith(("#", '"""', "'''")):
                continue

            findings.extend(self.scan_query(line, f"{file_path}:{line_num}"))

        return findings

    def scan_directory(
        self,
        directory: str,
        file_patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
    ) -> List[EscapePathFinding]:
        """
        Scan a directory recursively for potential escape paths.

        Args:
            directory: Root directory to scan
            file_patterns: Recursive glob patterns, relative to ``directory``,
                for files to include (default: ["**/*.py"])
            exclude_patterns: fnmatch patterns matched against each full path
                (``*`` also matches ``/``); default excludes ``test_*.py`` files,
                ``tests/`` directories and ``__pycache__``

        Returns:
            List of all findings across all files; each file is scanned at most
            once even if several include patterns match it
        """
        import fnmatch
        import glob
        import os

        findings: List[EscapePathFinding] = []
        file_patterns = file_patterns or ["**/*.py"]
        exclude_patterns = exclude_patterns or ["**/test_*.py", "**/tests/**", "**/__pycache__/**"]
        seen: Set[str] = set()

        for pattern in file_patterns:
            full_pattern = os.path.join(directory, pattern)
            for file_path in glob.glob(full_pattern, recursive=True):
                # Overlapping include patterns can yield the same path twice.
                if file_path in seen:
                    continue
                seen.add(file_path)

                if any(fnmatch.fnmatch(file_path, exclude) for exclude in exclude_patterns):
                    continue

                findings.extend(self.scan_file(file_path))

        return findings

    def generate_report(
        self,
        findings: List[EscapePathFinding],
        format: str = "text",
    ) -> str:
        """
        Generate a report from scan findings.

        Args:
            findings: List of findings to report
            format: Output format ("json", "markdown"; anything else is text)

        Returns:
            Formatted report string
        """
        if format == "json":
            import json

            return json.dumps(
                [
                    {
                        "severity": f.severity.value,
                        "category": f.category,
                        "description": f.description,
                        "location": f.location,
                        "remediation": f.remediation,
                        "compliance_impact": f.compliance_impact,
                    }
                    for f in findings
                ],
                indent=2,
            )

        elif format == "markdown":
            lines = ["# Tenant Isolation Escape Path Report\n"]

            # Group by severity
            by_severity: Dict[EscapePathSeverity, List[EscapePathFinding]] = {}
            for f in findings:
                by_severity.setdefault(f.severity, []).append(f)

            # Every severity, most severe first (enum definition order), so
            # INFO findings are reported too.
            for severity in EscapePathSeverity:
                if severity in by_severity:
                    lines.append(f"\n## {severity.value.upper()} ({len(by_severity[severity])})\n")
                    for f in by_severity[severity]:
                        lines.append(f"### {f.category}\n")
                        lines.append(f"- **Location**: `{f.location}`\n")
                        lines.append(f"- **Description**: {f.description}\n")
                        lines.append(f"- **Remediation**: {f.remediation}\n")
                        lines.append(f"- **Compliance**: {', '.join(f.compliance_impact)}\n")

            return "\n".join(lines)

        else:  # text format
            lines = ["=" * 60, "TENANT ISOLATION ESCAPE PATH REPORT", "=" * 60, ""]

            critical = [f for f in findings if f.severity == EscapePathSeverity.CRITICAL]
            high = [f for f in findings if f.severity == EscapePathSeverity.HIGH]
            medium = [f for f in findings if f.severity == EscapePathSeverity.MEDIUM]
            low = [f for f in findings if f.severity == EscapePathSeverity.LOW]

            lines.append(f"CRITICAL: {len(critical)} HIGH: {len(high)} MEDIUM: {len(medium)} LOW: {len(low)}")
            lines.append("")

            for f in findings:
                lines.append(str(f))
                lines.append("-" * 40)

            return "\n".join(lines)
922# =============================================================================
923# CI/CD Integration Utilities
924# =============================================================================
def ci_fail_on_findings(
    findings: List[EscapePathFinding],
    fail_on: Optional[Set[EscapePathSeverity]] = None,
) -> int:
    """
    Return exit code for CI/CD based on findings, printing a summary.

    Args:
        findings: List of scan findings
        fail_on: Severities that should cause failure (default: CRITICAL and
            HIGH). Any iterable of EscapePathSeverity is accepted.

    Returns:
        0 if no finding has a failing severity, 1 otherwise (for sys.exit())

    Example:
        scanner = TenantEscapePathScanner()
        findings = scanner.scan_directory("./src")
        sys.exit(ci_fail_on_findings(findings))
    """
    if fail_on is None:
        fail_on = {EscapePathSeverity.CRITICAL, EscapePathSeverity.HIGH}
    else:
        fail_on = set(fail_on)

    # Describe the configured severities, most severe first. The default
    # produces "critical/high", matching the historical messages exactly.
    label = "/".join(s.value for s in EscapePathSeverity if s in fail_on) or "failing"

    failing_findings = [f for f in findings if f.severity in fail_on]

    if failing_findings:
        print(f"CI FAILED: Found {len(failing_findings)} {label} severity findings")
        for f in failing_findings:
            print(f" - {f.severity.value.upper()}: {f.description} at {f.location}")
        return 1

    print(f"CI PASSED: No {label} findings ({len(findings)} total findings)")
    return 0
962# =============================================================================
963# Pytest Fixtures and Markers
964# =============================================================================
def pytest_configure(config: Any) -> None:
    """
    Register the pytest markers used by tenant isolation tests.

    Re-export from conftest.py to activate:
        from netrun.rbac.testing import pytest_configure
    """
    marker_lines = (
        "tenant_isolation: Mark test as a tenant isolation contract test",
        "escape_path: Mark test as testing a specific escape path scenario",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)
def tenant_isolation_test(func: AsyncFunc) -> AsyncFunc:
    """
    Decorator marking an async test as a tenant isolation contract test.

    Logs the start of the test and its outcome. Failures are logged with a
    label (isolation error, assertion, or other exception) and then re-raised
    unchanged.

    Example:
        @tenant_isolation_test
        async def test_cross_tenant_read_impossible(self, db_session):
            ...
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"Starting tenant isolation test: {func.__name__}")
        try:
            outcome = await func(*args, **kwargs)
        except (TenantIsolationError, Exception) as exc:
            # Same precedence as separate except clauses: isolation errors
            # first, then assertion failures, then everything else.
            if isinstance(exc, TenantIsolationError):
                kind = "Isolation Error"
            elif isinstance(exc, AssertionError):
                kind = "Assertion"
            else:
                kind = "Exception"
            logger.error(f"FAILED ({kind}): {func.__name__}: {exc}")
            raise
        logger.info(f"PASSED: {func.__name__}")
        return outcome

    return wrapper  # type: ignore
1016# =============================================================================
1017# Compliance Documentation
1018# =============================================================================
# Compliance controls addressed by tenant isolation testing, grouped by
# framework: framework name -> {control ID -> control title}.
# Rendered by get_compliance_documentation().
COMPLIANCE_MAPPING = dict(
    SOC2={
        "CC6.1": "Logical and Physical Access Controls",
        "CC6.2": "Role-Based Access Control",
        "CC6.3": "Segregation of Duties",
    },
    ISO27001={
        "A.9.1": "Business Requirements of Access Control",
        "A.9.4": "System and Application Access Control",
        "A.9.4.1": "Information Access Restriction",
    },
    NIST={
        "AC-4": "Information Flow Enforcement",
        "AC-5": "Separation of Duties",
        "AC-6": "Least Privilege",
    },
)
def get_compliance_documentation() -> str:
    """
    Render the compliance mapping and testing guidance as Markdown.

    Returns:
        A Markdown document with a title and introduction, one
        "## <framework>" section per entry of COMPLIANCE_MAPPING listing its
        controls, then fixed "Testing Requirements" and "CI/CD Integration"
        sections.
    """
    header = (
        "# Tenant Isolation Testing - Compliance Mapping",
        "",
        "The tenant isolation testing utilities in this module address the following",
        "compliance requirements:",
        "",
    )

    framework_sections: List[str] = []
    for framework_name, control_map in COMPLIANCE_MAPPING.items():
        framework_sections.append(f"## {framework_name}")
        framework_sections.extend(
            f"- **{control_id}**: {title}" for control_id, title in control_map.items()
        )
        framework_sections.append("")

    footer = (
        "## Testing Requirements",
        "",
        "To maintain compliance, the following tests MUST pass before any release:",
        "",
        "1. **test_cross_tenant_read_impossible** - Proves Tenant B cannot read Tenant A's data",
        "2. **test_cross_tenant_write_impossible** - Proves Tenant B cannot modify Tenant A's data",
        "3. **test_query_without_tenant_filter_fails** - Ensures queries are validated",
        "4. **test_pagination_includes_tenant_filter** - Prevents paginated data leaks",
        "5. **test_background_task_preserves_tenant** - Ensures async context is maintained",
        "",
        "## CI/CD Integration",
        "",
        "Run escape path scanning as part of your CI pipeline:",
        "",
        "```python",
        "from netrun.rbac.testing import TenantEscapePathScanner, ci_fail_on_findings",
        "",
        "scanner = TenantEscapePathScanner()",
        "findings = scanner.scan_directory('./src')",
        "sys.exit(ci_fail_on_findings(findings))",
        "```",
    )

    return "\n".join((*header, *framework_sections, *footer))
1090# =============================================================================
1091# Module Exports
1092# =============================================================================
# Public API of netrun.rbac.testing: the names exported by
# ``from netrun.rbac.testing import *``.
__all__ = [
    # Exceptions (re-exported for convenience; defined in .exceptions)
    "TenantIsolationError",
    # Core assertions
    "assert_tenant_isolation",
    "assert_tenant_isolation_sync",
    # Test context
    "TenantTestContext",
    "tenant_test_context",
    # Background task handling
    "BackgroundTaskTenantContext",
    "preserve_tenant_context",
    # Escape path detection
    "TenantEscapePathScanner",
    "EscapePathSeverity",
    "EscapePathFinding",
    # CI/CD utilities
    "ci_fail_on_findings",
    # Pytest integration
    "pytest_configure",
    "tenant_isolation_test",
    # Compliance
    "get_compliance_documentation",
    "COMPLIANCE_MAPPING",
]