Coverage for netrun / rbac / tenant.py: 42%
36 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-18 22:20 +0000
1"""
2Tenant Context Management for Multi-Tenant RBAC
4Extracted from: Intirkast middleware/tenant_context.py + app/core/database.py
5Provides PostgreSQL Row-Level Security (RLS) session variable management
6"""
import logging
from collections.abc import Callable
from typing import Any, Optional
from uuid import UUID

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from .exceptions import MissingTenantContextError
# Module-level logger, looked up by this module's dotted name.
logger = logging.getLogger(name=__name__)
20class TenantContext:
21 """
22 Tenant context for database session scoping
24 Stores tenant_id and user_id for PostgreSQL RLS enforcement
25 """
27 def __init__(self, tenant_id: str | UUID, user_id: Optional[str | UUID] = None):
28 """
29 Initialize tenant context
31 Args:
32 tenant_id: Tenant UUID (string or UUID object)
33 user_id: User UUID for audit logging (optional)
34 """
35 self.tenant_id = str(tenant_id)
36 self.user_id = str(user_id) if user_id else None
38 def to_dict(self) -> dict:
39 """
40 Convert to dictionary for storage
42 Returns:
43 Dictionary with tenant_id and user_id
44 """
45 return {"tenant_id": self.tenant_id, "user_id": self.user_id}
async def set_tenant_context(
    session: AsyncSession, tenant_id: str | UUID, user_id: Optional[str | UUID] = None
) -> None:
    """
    Set PostgreSQL session variables for Row-Level Security (RLS)

    Extracted from: Intirkast test_rls_isolation.py (set_rls_context)

    Sets (transaction-scoped, same lifetime as SET LOCAL):
    - app.current_tenant_id: Used by RLS policies to filter queries
    - app.current_user_id: Used for audit logging (only when user_id is truthy)

    The values are written with ``SELECT set_config(name, value, true)``
    instead of ``SET LOCAL name = :param``. PostgreSQL's SET command does not
    accept bind parameters, so a parameterized SET fails under drivers that
    bind server-side (asyncpg, psycopg 3). ``set_config`` is an ordinary
    function call, so the value travels as a genuine bind parameter and is
    never interpolated into the SQL text.

    Args:
        session: SQLAlchemy AsyncSession
        tenant_id: Tenant UUID to scope queries to
        user_id: User UUID for audit logging (optional)

    Raises:
        MissingTenantContextError: If tenant_id is falsy (None or "").

    Example RLS Policy:
        CREATE POLICY tenant_isolation_policy ON users
        FOR ALL
        USING (tenant_id = NULLIF(current_setting('app.current_tenant_id', true), '')::UUID);

    Usage:
        async with AsyncSessionLocal() as session:
            # Set tenant context
            await set_tenant_context(session, tenant_id="550e8400-...", user_id="660e8400-...")

            # All queries now automatically filtered by tenant_id
            result = await session.execute(select(User))
            users = result.scalars().all()  # Only returns users from specified tenant

    PostgreSQL Session Variables:
    - set_config(name, value, true): Variable persists for current transaction only
      (equivalent to SET LOCAL)
    - current_setting('app.current_tenant_id', true): Retrieves variable (true = no error if missing)
    - NULLIF(..., ''): Converts empty string to NULL (handles missing variable)
    - ::UUID: Casts string to UUID type
    """
    if not tenant_id:
        raise MissingTenantContextError("tenant_id is required for RLS enforcement")

    # Third argument is_local=true scopes the value to the current transaction,
    # exactly like SET LOCAL, but unlike SET it accepts a bound value.
    await session.execute(
        text("SELECT set_config('app.current_tenant_id', :tenant_id, true)"),
        {"tenant_id": str(tenant_id)},
    )
    logger.debug("RLS enabled for tenant: %s", tenant_id)

    # User context for audit logging is optional.
    if user_id:
        await session.execute(
            text("SELECT set_config('app.current_user_id', :user_id, true)"),
            {"user_id": str(user_id)},
        )
        logger.debug("Audit context set for user: %s", user_id)
async def clear_tenant_context(session: AsyncSession) -> None:
    """
    Reset the PostgreSQL RLS session variables.

    Extracted from: Intirkast test_rls_isolation.py (clear_rls_context)

    Issues ``RESET`` for, in this order:
    - app.current_tenant_id
    - app.current_user_id

    RESET returns each variable to its session default, which is whatever
    the variable would hold had it never been SET in this session.

    NOTE: clearing the context does not switch RLS off. With a policy of the
    form shown in set_tenant_context()'s docstring
    (tenant_id = NULLIF(current_setting('app.current_tenant_id', true), '')::UUID),
    an empty tenant setting makes that comparison NULL. Tenant-scoped queries
    therefore return no rows, not all rows. Cross-tenant (admin) reads need a
    role that bypasses RLS or a policy that explicitly allows them.

    Usage:
        async with AsyncSessionLocal() as session:
            # Set tenant context
            await set_tenant_context(session, tenant_id="550e8400-...")

            # Query tenant-scoped data
            result = await session.execute(select(User))
            users = result.scalars().all()

            # Drop the tenant/user context again
            await clear_tenant_context(session)
    """
    reset_statements = ("RESET app.current_tenant_id", "RESET app.current_user_id")
    for statement in reset_statements:
        await session.execute(text(statement))
    logger.debug("Tenant context cleared")
async def get_current_tenant_id(session: AsyncSession) -> Optional[str]:
    """
    Read the app.current_tenant_id PostgreSQL setting for this session.

    Returns:
        The tenant ID string, or None when the setting is missing (NULL)
        or empty.

    Usage:
        async with AsyncSessionLocal() as session:
            tenant_id = await get_current_tenant_id(session)
            if tenant_id:
                print(f"Current tenant: {tenant_id}")
    """
    # missing_ok=true: current_setting yields NULL rather than raising when
    # the variable has never been defined in this session.
    query_result = await session.execute(
        text("SELECT current_setting('app.current_tenant_id', true)")
    )
    setting_value = query_result.scalar()
    # Both NULL (None) and '' (defined but empty) collapse to None.
    return setting_value or None
async def get_current_user_id(session: AsyncSession) -> Optional[str]:
    """
    Read the app.current_user_id PostgreSQL setting for this session.

    Returns:
        The user ID string, or None when the setting is missing (NULL)
        or empty.

    Usage:
        async with AsyncSessionLocal() as session:
            user_id = await get_current_user_id(session)
            if user_id:
                print(f"Current user: {user_id}")
    """
    # missing_ok=true: current_setting yields NULL rather than raising when
    # the variable has never been defined in this session.
    query_result = await session.execute(
        text("SELECT current_setting('app.current_user_id', true)")
    )
    setting_value = query_result.scalar()
    # Both NULL (None) and '' (defined but empty) collapse to None.
    return setting_value or None
def get_db_with_rls(
    tenant_id_getter: Callable[..., Any],
    user_id_getter: Optional[Callable[..., Any]] = None,
) -> Callable[..., Any]:
    """
    FastAPI dependency factory to get database session with RLS enabled

    Extracted from: Intirkast app/core/database.py (get_db_with_rls)

    Args:
        tenant_id_getter: Function to extract tenant_id from request
        user_id_getter: Function to extract user_id from request (optional)

    Returns:
        An async dependency function. As shipped, awaiting it raises
        NotImplementedError until the session-factory placeholder inside it
        is replaced (see "PLACEHOLDER Pattern" below).

    Usage:
        # Define getter functions
        def get_tenant_id_from_request(request: Request) -> str:
            return request.state.tenant_id

        def get_user_id_from_request(request: Request) -> str:
            return request.state.user_id

        # Create dependency
        get_db_scoped = get_db_with_rls(
            tenant_id_getter=get_tenant_id_from_request,
            user_id_getter=get_user_id_from_request
        )

        # Use in route
        @router.get("/api/users")
        async def list_users(db: AsyncSession = Depends(get_db_scoped)):
            # All queries automatically scoped to tenant
            result = await db.execute(select(User))
            return result.scalars().all()

    PLACEHOLDER Pattern:
        Replace {{AsyncSessionLocal}} with your session factory:

        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        engine = create_async_engine("{{DATABASE_URL}}")
        AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession)
    """

    # NOTE(review): the getters are bound as plain parameter defaults, not as
    # Depends(tenant_id_getter) / Depends(user_id_getter). FastAPI only calls a
    # dependency declared via Depends(...), so as written it would treat
    # tenant_id/user_id as query parameters whose default is the getter
    # function object. Wrap them in Depends(...) when replacing the placeholder
    # (fastapi is not imported by this module).
    async def dependency(
        tenant_id: str = tenant_id_getter, user_id: Optional[str] = user_id_getter or None
    ):
        """
        Database session dependency with RLS context

        Args:
            tenant_id: Tenant ID from request
            user_id: User ID from request (optional)

        Yields:
            AsyncSession with RLS context set (once the placeholder below
            is replaced; until then this raises NotImplementedError)
        """
        # PLACEHOLDER: Replace with your AsyncSessionLocal
        # from your_app.database import AsyncSessionLocal

        # Temporary placeholder error
        raise NotImplementedError(
            "Replace {{AsyncSessionLocal}} placeholder with your session factory. "
            "See netrun_rbac.tenant.get_db_with_rls docstring for example."
        )

        # Example implementation (remove the raise above, uncomment, and replace):
        # async with AsyncSessionLocal() as session:
        #     try:
        #         # Set RLS context
        #         await set_tenant_context(session, tenant_id, user_id)
        #
        #         yield session
        #
        #         await session.commit()
        #     except Exception:
        #         await session.rollback()
        #         logger.exception("Database error")
        #         raise
        # (`async with` closes the session, so no explicit close() is needed.)

    return dependency