Coverage for netrun / rbac / testing.py: 24%

294 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-18 22:20 +0000

1""" 

2Tenant Isolation Contract Testing Utilities for Multi-Tenant Applications. 

3 

4Provides pgTAP-style assertions for verifying multi-tenant data isolation. 

5Use in integration tests to prove cross-tenant data access is impossible. 

6 

7This module is CRITICAL for security compliance (SOC2, ISO27001, NIST). 

8 

9Features: 

10- Query analysis to detect missing tenant filters 

11- Test context management for multi-tenant scenarios 

12- Background task tenant context preservation 

13- Escape path detection and prevention 

14- CI/CD integration utilities 

15 

16Usage: 

17 from netrun.rbac.testing import ( 

18 TenantIsolationError, 

19 assert_tenant_isolation, 

20 TenantTestContext, 

21 BackgroundTaskTenantContext, 

22 TenantEscapePathScanner, 

23 ) 

24 

25 # Assert query includes tenant filter 

26 query = select(Item).where(Item.tenant_id == tenant_id) 

27 await assert_tenant_isolation(query) 

28 

29 # Test cross-tenant isolation 

30 async with TenantTestContext(session) as ctx: 

31 # Create data in tenant A 

32 item = Item(name="Secret", tenant_id=ctx.tenant_a_id) 

33 session.add(item) 

34 

35 # Switch to tenant B and verify isolation 

36 await ctx.switch_to_tenant_b() 

37 items = await session.execute(select(Item)) 

38 assert len(items.scalars().all()) == 0 

39 

40Security Level: CRITICAL 

41Compliance: SOC2 CC6.1, ISO27001 A.9.4, NIST AC-4 

42""" 

43 

44from __future__ import annotations 

45 

46import asyncio 

47import contextvars 

48import functools 

49import logging 

50import re 

51import warnings 

52from contextlib import asynccontextmanager 

53from dataclasses import dataclass, field 

54from enum import Enum 

55from typing import ( 

56 Any, 

57 AsyncGenerator, 

58 Callable, 

59 Coroutine, 

60 Dict, 

61 List, 

62 Optional, 

63 Pattern, 

64 Protocol, 

65 Set, 

66 Tuple, 

67 TypeVar, 

68 Union, 

69) 

70from uuid import uuid4 

71 

72from sqlalchemy import Select, text 

73from sqlalchemy.ext.asyncio import AsyncSession 

74from sqlalchemy.sql import ClauseElement 

75 

76from .exceptions import TenantIsolationError 

77 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Generic type variables used by annotations in this module.
# AsyncFunc is bounded to callables that return a coroutine.
T = TypeVar("T")
AsyncFunc = TypeVar("AsyncFunc", bound=Callable[..., Coroutine[Any, Any, Any]])

# Tenant ID for the current async context; reads as None until something sets it.
_current_tenant_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_tenant_id",
    default=None,
)

88 

89 

class EscapePathSeverity(str, Enum):
    """
    Severity labels attached to tenant escape path findings.

    The class mixes in ``str``, so every member is also a string equal to
    its lowercase value. Members are declared from most to least severe.
    """

    # Immediate data leak risk.
    CRITICAL = "critical"
    # Likely data leak with specific conditions.
    HIGH = "high"
    # Potential leak in edge cases.
    MEDIUM = "medium"
    # Best practice violation, no immediate risk.
    LOW = "low"
    # Informational finding.
    INFO = "info"

98 

99 

@dataclass
class EscapePathFinding:
    """
    One issue reported by tenant escape path analysis.

    Attributes:
        severity: How serious the finding is.
        category: Kind of escape path (query, context, background, raw_sql).
        description: Human-readable explanation of the issue.
        location: Code location or query fragment where the issue was found.
        remediation: Suggested way to fix the issue.
        compliance_impact: Compliance controls affected; a fresh empty list
            per instance when not supplied.
    """

    severity: EscapePathSeverity
    category: str
    description: str
    location: str
    remediation: str
    compliance_impact: List[str] = field(default_factory=list)

    def __str__(self) -> str:
        """Render the finding as a three-line summary: headline, location, remediation."""
        headline = f"[{self.severity.value.upper()}] {self.category}: {self.description}"
        detail_lines = (
            f" Location: {self.location}",
            f" Remediation: {self.remediation}",
        )
        return "\n".join((headline, *detail_lines))

127 

128 

class SessionFactoryProtocol(Protocol):
    """Structural type for async callables that build a session for one tenant."""

    async def __call__(self, tenant_id: str) -> AsyncSession:
        """Return a session scoped to ``tenant_id``."""
        ...

135 

136 

137# ============================================================================= 

138# Core Assertion Functions 

139# ============================================================================= 

140 

141 

async def assert_tenant_isolation(
    query: Union[Select, ClauseElement, str],
    tenant_column: str = "tenant_id",
    session: Optional[AsyncSession] = None,
    strict: bool = True,
    allowed_patterns: Optional[List[Pattern[str]]] = None,
) -> None:
    """
    Assert that a query's SQL text contains a filter on the tenant column.

    The query is rendered to text and searched with regular expressions; a
    TenantIsolationError is raised when no tenant filter is found.

    Args:
        query: SQLAlchemy Select statement, ClauseElement, or raw SQL string.
        tenant_column: Name of the tenant column (default: "tenant_id").
            Matched case-insensitively, literally (regex metacharacters are
            escaped) and as a whole identifier, so ``parent_tenant_id`` does
            not count as ``tenant_id``.
        session: Accepted for API compatibility; not used by the check.
        strict: Accepted for API compatibility; not used by the check.
        allowed_patterns: Compiled regex patterns; if any matches the rendered
            query, the query is accepted without a tenant filter
            (e.g., system tables, public lookup tables).

    Raises:
        TenantIsolationError: If the query is missing a tenant filter.
        ValueError: If ``tenant_column`` is empty.

    Example:
        # This FAILS - no tenant filter
        query = select(Item).where(Item.status == "active")
        await assert_tenant_isolation(query)  # Raises TenantIsolationError!

        # This PASSES
        query = select(Item).where(Item.tenant_id == tenant_id, Item.status == "active")
        await assert_tenant_isolation(query)  # OK

        # Allow certain system queries
        allowed = [re.compile(r"system_config"), re.compile(r"lookup_")]
        query = select(SystemConfig)  # No tenant_id column
        await assert_tenant_isolation(query, allowed_patterns=allowed)  # OK

    Security Note:
        This function performs STATIC analysis of the query string.
        It cannot detect all possible bypass scenarios (e.g., dynamic SQL
        injection), and any mention of the tenant column after WHERE is
        accepted, even if it is not an equality filter.
        Use in combination with RLS policies for defense-in-depth.
    """
    if not tenant_column:
        # An empty column name would make every pattern below match any "=".
        raise ValueError("tenant_column must be a non-empty column name")

    # Render the query as text for analysis.
    if isinstance(query, str):
        compiled = query
    else:
        try:
            compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
        except Exception:
            # Best effort: not every construct can inline its bind values,
            # so fall back to the plain string form.
            compiled = str(query)

    # Explicitly allowed queries short-circuit the tenant check.
    if allowed_patterns:
        for pattern in allowed_patterns:
            if pattern.search(compiled):
                logger.debug("Query allowed by pattern %s: %s...", pattern.pattern, compiled[:100])
                return

    compiled_lower = compiled.lower()

    # Match the column literally and only as a whole identifier.
    column = rf"(?<!\w){re.escape(tenant_column.lower())}(?!\w)"

    tenant_filter_patterns = [
        # Equality on the column: WHERE, JOIN ... ON, and bound parameters
        # (":name") all reduce to "<column> =".
        rf"{column}\s*=",
        # IN clause with tenant
        rf"{column}\s+in\s*\(",
        # Any other reference after WHERE on the same line (e.g. subquery correlation)
        rf"where\s+.*{column}",
    ]

    has_tenant_filter = any(
        re.search(pattern, compiled_lower) for pattern in tenant_filter_patterns
    )

    if not has_tenant_filter:
        # Truncate long queries for the error message.
        query_preview = compiled[:500] + "..." if len(compiled) > 500 else compiled

        raise TenantIsolationError(
            f"Query missing tenant isolation! Expected '{tenant_column}' filter.\n"
            f"Query: {query_preview}\n\n"
            f"REMEDIATION: Add .where({tenant_column} == tenant_id) to your query.\n"
            f"COMPLIANCE IMPACT: SOC2 CC6.1, ISO27001 A.9.4.1, NIST AC-4"
        )

    logger.debug("Tenant isolation verified for query: %s...", compiled[:100])

236 

237 

def assert_tenant_isolation_sync(
    query: Union[Select, ClauseElement, str],
    tenant_column: str = "tenant_id",
    strict: bool = True,
    allowed_patterns: Optional[List[Pattern[str]]] = None,
) -> None:
    """
    Blocking wrapper around assert_tenant_isolation.

    Intended for non-async code such as synchronous pytest fixtures. The
    check runs to completion on a private event loop that is closed before
    returning, whether the check passes or raises.
    """
    private_loop = asyncio.new_event_loop()
    try:
        pending_check = assert_tenant_isolation(
            query=query,
            tenant_column=tenant_column,
            strict=strict,
            allowed_patterns=allowed_patterns,
        )
        private_loop.run_until_complete(pending_check)
    finally:
        private_loop.close()

262 

263 

264# ============================================================================= 

265# Test Context Management 

266# ============================================================================= 

267 

268 

class TenantTestContext:
    """
    Async context manager for testing tenant isolation.

    Holds two tenant IDs (A and B), activates tenant A on entry and lets the
    test switch between tenants. The active tenant is written to a
    PostgreSQL session variable (transaction-scoped) and mirrored into the
    module-level ``_current_tenant_id`` context variable.

    Example:
        async with TenantTestContext(session) as ctx:
            # Create data in tenant A (context starts here)
            item_a = Item(name="Secret", tenant_id=ctx.tenant_a_id)
            session.add(item_a)
            await session.commit()

            # Switch to tenant B and try to read
            await ctx.switch_to_tenant_b()

            # This query should return empty due to RLS
            result = await session.execute(select(Item))
            items = result.scalars().all()
            assert item_a not in items, "CRITICAL: Tenant B can see Tenant A's data!"

    Attributes:
        tenant_a_id: ID for test tenant A
        tenant_b_id: ID for test tenant B
        current_tenant: Currently active tenant ID ("" after clear_tenant_context)
        session: Database session the tenant variable is set on
    """

    def __init__(
        self,
        session: AsyncSession,
        tenant_a_id: Optional[str] = None,
        tenant_b_id: Optional[str] = None,
        session_variable: str = "app.current_tenant_id",
        user_session_variable: str = "app.current_user_id",
        auto_cleanup: bool = True,
    ):
        """
        Initialize tenant test context.

        Args:
            session: SQLAlchemy AsyncSession to use for testing
            tenant_a_id: Override tenant A ID (default: auto-generated)
            tenant_b_id: Override tenant B ID (default: auto-generated)
            session_variable: PostgreSQL session variable for tenant ID
            user_session_variable: PostgreSQL session variable for user ID
            auto_cleanup: Whether to reset context on exit
        """
        self.session = session
        self.tenant_a_id = tenant_a_id or f"test-tenant-a-{uuid4().hex[:8]}"
        self.tenant_b_id = tenant_b_id or f"test-tenant-b-{uuid4().hex[:8]}"
        self.current_tenant = self.tenant_a_id
        self.session_variable = session_variable
        self.user_session_variable = user_session_variable
        self.auto_cleanup = auto_cleanup
        self._original_tenant: Optional[str] = None
        self._context_history: List[Tuple[str, str]] = []  # (action, tenant_id)

    async def __aenter__(self) -> "TenantTestContext":
        """Remember any pre-existing tenant setting, then activate tenant A."""
        try:
            self._original_tenant = await self._read_session_tenant()
        except Exception:
            # Best effort: an unreadable setting is treated as "nothing to restore".
            self._original_tenant = None

        await self._set_tenant(self.tenant_a_id)
        self._context_history.append(("enter", self.tenant_a_id))

        logger.info("TenantTestContext initialized. Tenant A: %s", self.tenant_a_id)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """
        Exit the context and, when auto_cleanup is set, restore the session state.

        If the body already raised, a failure during cleanup is logged instead
        of raised so it does not replace the original exception.
        """
        if self.auto_cleanup:
            try:
                await self._restore_session_state()
            except Exception:
                if exc_type is None:
                    raise
                logger.warning(
                    "TenantTestContext cleanup failed while handling %s",
                    exc_type.__name__,
                    exc_info=True,
                )

        self._context_history.append(("exit", self.current_tenant))
        logger.info("TenantTestContext exited. History: %d actions", len(self._context_history))

    async def _restore_session_state(self) -> None:
        """Put back the tenant seen on entry, or reset the variables if there was none."""
        if self._original_tenant:
            await self._set_tenant(self._original_tenant)
        else:
            # RESET takes an identifier, which cannot be a bind parameter.
            await self.session.execute(text(f"RESET {self.session_variable}"))
            # Do not leave a test tenant visible through the context variable.
            _current_tenant_id.set(None)
        await self.session.execute(text(f"RESET {self.user_session_variable}"))

    async def _read_session_tenant(self) -> Optional[str]:
        """Return the raw value of the tenant session variable (None if unset)."""
        result = await self.session.execute(
            # Second argument true = missing_ok: no error when the variable is unset.
            text("SELECT current_setting(:name, true)"),
            {"name": self.session_variable},
        )
        return result.scalar()

    async def _set_tenant(self, tenant_id: str) -> None:
        """Set tenant context via PostgreSQL session variable."""
        # SET/SET LOCAL do not accept bind parameters; set_config() does.
        # Its third argument (is_local=true) limits the change to the current
        # transaction, the same scope as SET LOCAL.
        await self.session.execute(
            text("SELECT set_config(:name, :value, true)"),
            {"name": self.session_variable, "value": tenant_id},
        )
        self.current_tenant = tenant_id
        _current_tenant_id.set(tenant_id)

    async def switch_to_tenant_a(self) -> None:
        """Switch to tenant A context."""
        await self._set_tenant(self.tenant_a_id)
        self._context_history.append(("switch_a", self.tenant_a_id))
        logger.debug("Switched to tenant A: %s", self.tenant_a_id)

    async def switch_to_tenant_b(self) -> None:
        """Switch to tenant B context."""
        await self._set_tenant(self.tenant_b_id)
        self._context_history.append(("switch_b", self.tenant_b_id))
        logger.debug("Switched to tenant B: %s", self.tenant_b_id)

    async def switch_to_tenant(self, tenant_id: str) -> None:
        """Switch to arbitrary tenant context (for advanced testing)."""
        await self._set_tenant(tenant_id)
        self._context_history.append(("switch_custom", tenant_id))
        logger.debug("Switched to custom tenant: %s", tenant_id)

    async def clear_tenant_context(self) -> None:
        """
        Clear tenant context (simulate superuser/admin access).

        WARNING: Use only for testing admin bypass scenarios.
        """
        await self.session.execute(text(f"RESET {self.session_variable}"))
        self.current_tenant = ""
        _current_tenant_id.set(None)
        self._context_history.append(("clear", ""))
        logger.warning("Tenant context cleared - operating without RLS filtering")

    async def get_current_tenant(self) -> Optional[str]:
        """Get the currently set tenant ID from PostgreSQL session (None if unset or empty)."""
        return await self._read_session_tenant() or None

    def get_context_history(self) -> List[Tuple[str, str]]:
        """Get a copy of the (action, tenant_id) history of context switches for debugging."""
        return self._context_history.copy()

414 

415 

@asynccontextmanager
async def tenant_test_context(
    session: AsyncSession,
    **kwargs: Any,
) -> AsyncGenerator[TenantTestContext, None]:
    """
    Function-style entry point for tenant isolation testing.

    Builds a TenantTestContext from ``session`` and any extra keyword
    arguments, enters it, and yields it to the caller.

    Example:
        async with tenant_test_context(session) as ctx:
            # Create data in tenant A
            ...
            # Switch and verify isolation
            await ctx.switch_to_tenant_b()
            ...
    """
    async with TenantTestContext(session, **kwargs) as active_ctx:
        yield active_ctx

437 

438 

439# ============================================================================= 

440# Background Task Context Preservation 

441# ============================================================================= 

442 

443 

class BackgroundTaskTenantContext:
    """
    Wrapper for background tasks that preserves tenant context.

    CRITICAL: Background tasks lose request context by default!
    FastAPI's BackgroundTasks runs after the response is sent,
    meaning the original request's tenant context is lost.

    Example:
        # WRONG - loses tenant context
        background_tasks.add_task(process_items)

        # RIGHT - preserves tenant context
        background_tasks.add_task(
            BackgroundTaskTenantContext(tenant_id, session_factory).run(process_items)
        )

    For Celery/Redis Queue integration:
        # In task definition
        @celery_app.task
        def process_items_task(tenant_id: str, item_ids: list):
            async def inner():
                ctx = BackgroundTaskTenantContext(tenant_id, session_factory)
                async with ctx.get_session() as session:
                    await process_items(session, item_ids)
            asyncio.run(inner())

    Security Note:
        Always pass tenant_id explicitly - never rely on context inheritance.
        This ensures audit trails are maintained for background operations.
    """

    def __init__(
        self,
        tenant_id: str,
        session_factory: Optional[SessionFactoryProtocol] = None,
        user_id: Optional[str] = None,
        correlation_id: Optional[str] = None,
    ):
        """
        Initialize background task tenant context.

        Args:
            tenant_id: Tenant ID to scope the background task
            session_factory: Async function that creates a session with tenant context
            user_id: User ID for audit logging (optional)
            correlation_id: Request correlation ID for tracing (optional;
                a random hex ID is generated when omitted)
        """
        self.tenant_id = tenant_id
        self.session_factory = session_factory
        self.user_id = user_id
        self.correlation_id = correlation_id or uuid4().hex

    def run(
        self,
        func: AsyncFunc,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[[], Coroutine[Any, Any, Any]]:
        """
        Wrap an async function to run with tenant context.

        While the wrapped function runs, the module-level tenant context
        variable holds this tenant ID; its previous value is restored
        afterwards. When a session factory is configured, a session is
        created, passed to ``func`` as the ``session`` keyword argument, and
        closed once ``func`` finishes or raises.

        Args:
            func: Async function to wrap
            *args: Arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function

        Returns:
            Zero-argument async function that can be added to BackgroundTasks
        """
        # Callables such as functools.partial have no __name__.
        task_name = getattr(func, "__name__", repr(func))

        @functools.wraps(func)
        async def wrapped() -> Any:
            # Keep the token so the previous tenant can be restored afterwards.
            token = _current_tenant_id.set(self.tenant_id)

            logger.info(
                "Background task starting: %s [tenant=%s, correlation=%s]",
                task_name,
                self.tenant_id,
                self.correlation_id,
            )

            try:
                if self.session_factory:
                    # get_session() closes the session even when func raises.
                    async with self.get_session() as session:
                        result = await func(*args, session=session, **kwargs)
                else:
                    # Session should be provided in kwargs
                    result = await func(*args, **kwargs)
            except Exception as e:
                logger.error(
                    "Background task failed: %s [tenant=%s, correlation=%s]: %s",
                    task_name,
                    self.tenant_id,
                    self.correlation_id,
                    e,
                )
                raise
            else:
                # Only report completion when the task did not fail.
                logger.info(
                    "Background task completed: %s [tenant=%s, correlation=%s]",
                    task_name,
                    self.tenant_id,
                    self.correlation_id,
                )
                return result
            finally:
                _current_tenant_id.reset(token)

        return wrapped

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """
        Get a session with tenant context set.

        Requires session_factory to be configured. The session is closed
        when the ``async with`` block exits.

        Raises:
            ValueError: If no session_factory was provided.

        Example:
            async with BackgroundTaskTenantContext(tenant_id, factory).get_session() as session:
                result = await session.execute(select(Item))
        """
        if not self.session_factory:
            raise ValueError(
                "session_factory must be provided to use get_session(). "
                "Either pass session_factory in constructor or provide session manually."
            )

        session = await self.session_factory(self.tenant_id)
        try:
            yield session
        finally:
            await session.close()

568 

569 

def preserve_tenant_context(
    tenant_id: str,
    session_factory: Optional[SessionFactoryProtocol] = None,
) -> Callable[[AsyncFunc], Callable[..., Coroutine[Any, Any, Any]]]:
    """
    Build a decorator that runs an async function inside a tenant context.

    Each call of the decorated function creates a new
    BackgroundTaskTenantContext for ``tenant_id`` / ``session_factory``,
    wraps the call with its ``run`` method, and awaits the result.

    Example:
        @preserve_tenant_context(tenant_id, session_factory)
        async def process_items(session, item_ids):
            ...

        # Add to background tasks
        background_tasks.add_task(process_items, item_ids)
    """

    def decorator(func: AsyncFunc) -> Callable[..., Coroutine[Any, Any, Any]]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            task_context = BackgroundTaskTenantContext(tenant_id, session_factory)
            bound_task = task_context.run(func, *args, **kwargs)
            return await bound_task()

        return wrapper

    return decorator

595 

596 

597# ============================================================================= 

598# Escape Path Detection and Scanning 

599# ============================================================================= 

600 

601 

class TenantEscapePathScanner:
    """
    Scans code and queries for potential tenant isolation escape paths.

    Use in CI/CD pipelines to detect security issues before deployment.
    Detection is regex-based and works one line (or one query string) at a
    time, so it is a heuristic: expect both false positives and misses.

    Detects:
    - Queries without tenant filters
    - Raw SQL that bypasses ORM
    - Background tasks without context preservation
    - Pagination without tenant scope
    - Aggregations that leak tenant boundaries
    - JOIN queries with missing tenant conditions
    - Subqueries without tenant correlation

    Example:
        scanner = TenantEscapePathScanner()

        # Scan a query
        findings = scanner.scan_query(query_string)

        # Scan a Python file
        findings = scanner.scan_file("/path/to/repo.py")

        # Scan entire directory
        findings = scanner.scan_directory("/path/to/repo")

        # Fail CI if critical findings
        critical = [f for f in findings if f.severity == EscapePathSeverity.CRITICAL]
        if critical:
            sys.exit(1)
    """

    # Patterns that indicate potential escape paths:
    # (pattern, severity, category, description)
    DANGEROUS_PATTERNS: List[Tuple[Pattern[str], EscapePathSeverity, str, str]] = [
        # Raw SQL execution
        (
            re.compile(r"execute\s*\(\s*['\"]SELECT", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "raw_sql",
            "Raw SELECT query detected. Ensure tenant filter is included.",
        ),
        (
            re.compile(r"execute\s*\(\s*['\"]UPDATE", re.IGNORECASE),
            EscapePathSeverity.CRITICAL,
            "raw_sql",
            "Raw UPDATE query detected. Could modify other tenants' data.",
        ),
        (
            re.compile(r"execute\s*\(\s*['\"]DELETE", re.IGNORECASE),
            EscapePathSeverity.CRITICAL,
            "raw_sql",
            "Raw DELETE query detected. Could delete other tenants' data.",
        ),
        # Pagination without tenant filter
        (
            re.compile(r"\.offset\s*\([^)]+\)\.limit\s*\([^)]+\)(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "pagination",
            "Pagination detected without visible tenant filter.",
        ),
        # Aggregation across tenants
        (
            re.compile(r"func\.(count|sum|avg|max|min)\s*\((?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "aggregation",
            "Aggregation function without tenant filter could leak cross-tenant metrics.",
        ),
        # Background task without context: any add_task( call whose remaining
        # text on the line does not mention TenantContext. (A character class
        # such as [^B] must not be used here: under IGNORECASE it would skip
        # every call whose arguments contain the letter "b".)
        (
            re.compile(r"add_task\s*\((?!.*TenantContext)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "background",
            "BackgroundTask may lose tenant context. Use BackgroundTaskTenantContext.",
        ),
        # Session without RLS
        (
            re.compile(r"AsyncSession\s*\(\s*\)(?!.*set_tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "session",
            "Session created without explicit RLS context setup.",
        ),
        # UNION queries (potential cross-tenant join)
        (
            re.compile(r"\bUNION\b(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.HIGH,
            "union",
            "UNION query detected. Ensure all subqueries have tenant filters.",
        ),
        # Subquery without correlation
        (
            re.compile(r"\(\s*SELECT[^)]+FROM[^)]+\)(?!.*tenant)", re.IGNORECASE),
            EscapePathSeverity.MEDIUM,
            "subquery",
            "Subquery without visible tenant correlation.",
        ),
    ]

    # Patterns that indicate SAFE tenant handling
    SAFE_PATTERNS: List[Pattern[str]] = [
        re.compile(r"tenant_id\s*=", re.IGNORECASE),
        re.compile(r"\.filter\s*\([^)]*tenant", re.IGNORECASE),
        re.compile(r"\.where\s*\([^)]*tenant", re.IGNORECASE),
        re.compile(r"current_setting\s*\(['\"]app\.current_tenant", re.IGNORECASE),
        re.compile(r"set_tenant_context", re.IGNORECASE),
        re.compile(r"BackgroundTaskTenantContext", re.IGNORECASE),
        re.compile(r"TenantTestContext", re.IGNORECASE),
    ]

    def __init__(
        self,
        tenant_column: str = "tenant_id",
        custom_dangerous_patterns: Optional[List[Tuple[Pattern[str], EscapePathSeverity, str, str]]] = None,
        custom_safe_patterns: Optional[List[Pattern[str]]] = None,
        ignore_patterns: Optional[List[Pattern[str]]] = None,
    ):
        """
        Initialize escape path scanner.

        Args:
            tenant_column: Name of tenant column to check for. When it is not
                "tenant_id", an equality filter on this column is also
                treated as safe tenant handling.
            custom_dangerous_patterns: Additional patterns to detect
            custom_safe_patterns: Additional patterns that indicate safe handling
            ignore_patterns: Patterns matched against file paths; matching
                files are skipped by scan_file
        """
        self.tenant_column = tenant_column
        # Copy the class-level lists so per-instance additions stay local.
        self.dangerous_patterns = list(self.DANGEROUS_PATTERNS)
        self.safe_patterns = list(self.SAFE_PATTERNS)
        self.ignore_patterns = ignore_patterns or []

        # The built-in safe patterns only know "tenant_id".
        if tenant_column and tenant_column.lower() != "tenant_id":
            self.safe_patterns.append(
                re.compile(rf"{re.escape(tenant_column)}\s*=", re.IGNORECASE)
            )

        if custom_dangerous_patterns:
            self.dangerous_patterns.extend(custom_dangerous_patterns)
        if custom_safe_patterns:
            self.safe_patterns.extend(custom_safe_patterns)

    def scan_query(self, query: str, context: str = "unknown") -> List[EscapePathFinding]:
        """
        Scan a query string for potential escape paths.

        If any safe pattern matches, the query is accepted and no findings
        are produced; otherwise one finding is produced per matching
        dangerous pattern.

        Args:
            query: SQL query or SQLAlchemy query string
            context: Context information (e.g., file:line)

        Returns:
            List of findings
        """
        # Any sign of tenant handling short-circuits the scan.
        if any(safe_pattern.search(query) for safe_pattern in self.safe_patterns):
            return []

        return [
            EscapePathFinding(
                severity=severity,
                category=category,
                description=description,
                location=context,
                remediation=f"Add {self.tenant_column} filter to query",
                compliance_impact=["SOC2 CC6.1", "ISO27001 A.9.4", "NIST AC-4"],
            )
            for pattern, severity, category, description in self.dangerous_patterns
            if pattern.search(query)
        ]

    def scan_file(self, file_path: str) -> List[EscapePathFinding]:
        """
        Scan a Python file for potential escape paths.

        Files whose path matches an ignore pattern are skipped without being
        read. Unreadable files are logged and produce no findings.

        Args:
            file_path: Path to Python file

        Returns:
            List of findings with line numbers
        """
        findings: List[EscapePathFinding] = []

        # Skip ignored files before touching the disk.
        if any(ignore_pattern.search(file_path) for ignore_pattern in self.ignore_patterns):
            return findings

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except (OSError, UnicodeError) as e:
            logger.warning("Could not read file %s: %s", file_path, e)
            return findings

        # Scan line by line so each finding carries a file:line location.
        for line_num, line in enumerate(content.split("\n"), 1):
            # Skip comment lines and lines that open a triple-quoted string.
            if line.strip().startswith(("#", '"""', "'''")):
                continue

            findings.extend(self.scan_query(line, f"{file_path}:{line_num}"))

        return findings

    def scan_directory(
        self,
        directory: str,
        file_patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
    ) -> List[EscapePathFinding]:
        """
        Scan a directory recursively for potential escape paths.

        Each file is scanned at most once, even when several include
        patterns match it.

        Args:
            directory: Root directory to scan
            file_patterns: Glob patterns for files to include (default: ["**/*.py"])
            exclude_patterns: Glob patterns for files to exclude (default:
                test files, tests directories and __pycache__)

        Returns:
            List of all findings across all files
        """
        import fnmatch
        import glob
        import os

        findings: List[EscapePathFinding] = []
        file_patterns = file_patterns or ["**/*.py"]
        exclude_patterns = exclude_patterns or ["**/test_*.py", "**/tests/**", "**/__pycache__/**"]

        seen: Set[str] = set()
        for pattern in file_patterns:
            full_pattern = os.path.join(directory, pattern)
            for file_path in glob.glob(full_pattern, recursive=True):
                if file_path in seen:
                    continue
                seen.add(file_path)

                if any(fnmatch.fnmatch(file_path, exclude) for exclude in exclude_patterns):
                    continue

                findings.extend(self.scan_file(file_path))

        return findings

    def generate_report(
        self,
        findings: List[EscapePathFinding],
        format: str = "text",
    ) -> str:
        """
        Generate a report from scan findings.

        Args:
            findings: List of findings to report
            format: Output format ("text", "json", "markdown"); any other
                value produces the text report

        Returns:
            Formatted report string
        """
        if format == "json":
            import json

            return json.dumps(
                [
                    {
                        "severity": f.severity.value,
                        "category": f.category,
                        "description": f.description,
                        "location": f.location,
                        "remediation": f.remediation,
                        "compliance_impact": f.compliance_impact,
                    }
                    for f in findings
                ],
                indent=2,
            )

        elif format == "markdown":
            lines = ["# Tenant Isolation Escape Path Report\n"]

            # Group by severity
            by_severity: Dict[EscapePathSeverity, List[EscapePathFinding]] = {}
            for f in findings:
                by_severity.setdefault(f.severity, []).append(f)

            # Declaration order runs from most to least severe and includes
            # INFO, so no finding is silently dropped from the report.
            for severity in EscapePathSeverity:
                if severity in by_severity:
                    lines.append(f"\n## {severity.value.upper()} ({len(by_severity[severity])})\n")
                    for f in by_severity[severity]:
                        lines.append(f"### {f.category}\n")
                        lines.append(f"- **Location**: `{f.location}`\n")
                        lines.append(f"- **Description**: {f.description}\n")
                        lines.append(f"- **Remediation**: {f.remediation}\n")
                        lines.append(f"- **Compliance**: {', '.join(f.compliance_impact)}\n")

            return "\n".join(lines)

        else:  # text format
            lines = ["=" * 60, "TENANT ISOLATION ESCAPE PATH REPORT", "=" * 60, ""]

            critical = [f for f in findings if f.severity == EscapePathSeverity.CRITICAL]
            high = [f for f in findings if f.severity == EscapePathSeverity.HIGH]
            medium = [f for f in findings if f.severity == EscapePathSeverity.MEDIUM]
            low = [f for f in findings if f.severity == EscapePathSeverity.LOW]

            lines.append(f"CRITICAL: {len(critical)} HIGH: {len(high)} MEDIUM: {len(medium)} LOW: {len(low)}")
            lines.append("")

            for f in findings:
                lines.append(str(f))
                lines.append("-" * 40)

            return "\n".join(lines)

920 

921 

922# ============================================================================= 

923# CI/CD Integration Utilities 

924# ============================================================================= 

925 

926 

def ci_fail_on_findings(
    findings: List[EscapePathFinding],
    fail_on: Optional[Set[EscapePathSeverity]] = None,
) -> int:
    """
    Return exit code for CI/CD based on findings.

    Prints a summary to stdout. The summary names the severities that were
    actually checked, so a custom ``fail_on`` is reported accurately.

    Args:
        findings: List of scan findings
        fail_on: Set of severities that should cause failure
            (default: CRITICAL and HIGH)

    Returns:
        1 if any finding has a severity in ``fail_on``, 0 otherwise
        (for sys.exit())

    Example:
        scanner = TenantEscapePathScanner()
        findings = scanner.scan_directory("./src")
        sys.exit(ci_fail_on_findings(findings))
    """
    if fail_on is None:
        fail_on = {EscapePathSeverity.CRITICAL, EscapePathSeverity.HIGH}

    # Label the checked severities from most to least severe, e.g. "critical/high".
    rank = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4}
    checked = sorted(
        (str(getattr(severity, "value", severity)) for severity in fail_on),
        key=lambda value: (rank.get(value, len(rank)), value),
    )
    label = "/".join(checked) or "blocking"

    failing_findings = [f for f in findings if f.severity in fail_on]

    if failing_findings:
        print(f"CI FAILED: Found {len(failing_findings)} {label} severity findings")
        for f in failing_findings:
            print(f" - {f.severity.value.upper()}: {f.description} at {f.location}")
        return 1

    print(f"CI PASSED: No {label} findings ({len(findings)} total findings)")
    return 0

960 

961 

962# ============================================================================= 

963# Pytest Fixtures and Markers 

964# ============================================================================= 

965 

966 

def pytest_configure(config: Any) -> None:
    """
    Declare the custom markers used by tenant isolation tests.

    Each marker is registered by appending one line to the ``markers``
    ini option through ``config.addinivalue_line``.

    Add to conftest.py:
        from netrun.rbac.testing import pytest_configure

    Args:
        config: Object exposing ``addinivalue_line(name, line)``
            (presumably pytest's config object -- confirm against caller).
    """
    marker_lines = (
        "tenant_isolation: Mark test as a tenant isolation contract test",
        "escape_path: Mark test as testing a specific escape path scenario",
    )
    # Registration order matches the order of the tuple above.
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)

982 

983 

def tenant_isolation_test(func: AsyncFunc) -> AsyncFunc:
    """
    Decorator that logs the lifecycle of a tenant isolation contract test.

    The wrapped coroutine is awaited unchanged. A start message is logged
    before the call, a PASSED message after a normal return, and a FAILED
    message (tagged by failure kind) before any exception is re-raised.

    Example:
        @tenant_isolation_test
        async def test_cross_tenant_read_impossible(self, db_session):
            ...
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        def report_failure(kind: str, error: BaseException) -> None:
            # Single place that formats the failure log line.
            logger.error(f"FAILED ({kind}): {func.__name__}: {error}")

        logger.info(f"Starting tenant isolation test: {func.__name__}")
        try:
            outcome = await func(*args, **kwargs)
            logger.info(f"PASSED: {func.__name__}")
            return outcome
        # Handler order matters: the most specific failure kind is tried first.
        except TenantIsolationError as isolation_error:
            report_failure("Isolation Error", isolation_error)
            raise
        except AssertionError as assertion_error:
            report_failure("Assertion", assertion_error)
            raise
        except Exception as unexpected_error:
            report_failure("Exception", unexpected_error)
            raise

    return wrapper  # type: ignore

1014 

1015 

1016# ============================================================================= 

1017# Compliance Documentation 

1018# ============================================================================= 

1019 

1020 

# Compliance framework -> {control identifier -> short control title}.
# get_compliance_documentation() renders frameworks and controls in the
# insertion order used here.
COMPLIANCE_MAPPING: Dict[str, Dict[str, str]] = {
    # SOC 2 controls
    "SOC2": {
        "CC6.1": "Logical and Physical Access Controls",
        "CC6.2": "Role-Based Access Control",
        "CC6.3": "Segregation of Duties",
    },
    # ISO/IEC 27001 controls
    "ISO27001": {
        "A.9.1": "Business Requirements of Access Control",
        "A.9.4": "System and Application Access Control",
        "A.9.4.1": "Information Access Restriction",
    },
    # NIST controls
    "NIST": {
        "AC-4": "Information Flow Enforcement",
        "AC-5": "Separation of Duties",
        "AC-6": "Least Privilege",
    },
}

1038 

1039 

def get_compliance_documentation() -> str:
    """
    Render the compliance controls covered by tenant isolation testing.

    The text is Markdown: a title and introduction, one ``##`` section per
    framework in COMPLIANCE_MAPPING (one bullet per control), then fixed
    "Testing Requirements" and "CI/CD Integration" sections.

    Returns:
        The documentation as a single newline-joined string
    """
    header = [
        "# Tenant Isolation Testing - Compliance Mapping",
        "",
        "The tenant isolation testing utilities in this module address the following",
        "compliance requirements:",
        "",
    ]

    # One heading per framework, a bullet per control, then a blank separator.
    framework_sections: List[str] = []
    for framework, controls in COMPLIANCE_MAPPING.items():
        framework_sections.append(f"## {framework}")
        framework_sections.extend(
            f"- **{control_id}**: {description}"
            for control_id, description in controls.items()
        )
        framework_sections.append("")

    footer = [
        "## Testing Requirements",
        "",
        "To maintain compliance, the following tests MUST pass before any release:",
        "",
        "1. **test_cross_tenant_read_impossible** - Proves Tenant B cannot read Tenant A's data",
        "2. **test_cross_tenant_write_impossible** - Proves Tenant B cannot modify Tenant A's data",
        "3. **test_query_without_tenant_filter_fails** - Ensures queries are validated",
        "4. **test_pagination_includes_tenant_filter** - Prevents paginated data leaks",
        "5. **test_background_task_preserves_tenant** - Ensures async context is maintained",
        "",
        "## CI/CD Integration",
        "",
        "Run escape path scanning as part of your CI pipeline:",
        "",
        "```python",
        "from netrun.rbac.testing import TenantEscapePathScanner, ci_fail_on_findings",
        "",
        "scanner = TenantEscapePathScanner()",
        "findings = scanner.scan_directory('./src')",
        "sys.exit(ci_fail_on_findings(findings))",
        "```",
    ]

    return "\n".join([*header, *framework_sections, *footer])

1088 

1089 

1090# ============================================================================= 

1091# Module Exports 

1092# ============================================================================= 

1093 

# Public API of this module, grouped by purpose.
__all__ = [
    # Exception type (re-exported so tests can import it from here)
    "TenantIsolationError",
    # Query-level isolation assertions
    "assert_tenant_isolation",
    "assert_tenant_isolation_sync",
    # Multi-tenant test context helpers
    "TenantTestContext",
    "tenant_test_context",
    # Tenant context handling for background tasks
    "BackgroundTaskTenantContext",
    "preserve_tenant_context",
    # Escape path scanning
    "TenantEscapePathScanner",
    "EscapePathSeverity",
    "EscapePathFinding",
    # CI/CD exit-code helper
    "ci_fail_on_findings",
    # Pytest hook and test decorator
    "pytest_configure",
    "tenant_isolation_test",
    # Compliance documentation
    "get_compliance_documentation",
    "COMPLIANCE_MAPPING",
]