Coverage for netrun / rbac / tenant.py: 42%

36 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-18 22:20 +0000

1""" 

2Tenant Context Management for Multi-Tenant RBAC 

3 

4Extracted from: Intirkast middleware/tenant_context.py + app/core/database.py 

5Provides PostgreSQL Row-Level Security (RLS) session variable management 

6""" 

7 

import logging
from typing import Any, Callable, Optional
from uuid import UUID

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from .exceptions import MissingTenantContextError

16 

# Module-scoped logger named after this module; the tenant-context helpers
# in this file report context changes through it with logger.debug().
logger = logging.getLogger(__name__)

18 

19 

20class TenantContext: 

21 """ 

22 Tenant context for database session scoping 

23 

24 Stores tenant_id and user_id for PostgreSQL RLS enforcement 

25 """ 

26 

27 def __init__(self, tenant_id: str | UUID, user_id: Optional[str | UUID] = None): 

28 """ 

29 Initialize tenant context 

30 

31 Args: 

32 tenant_id: Tenant UUID (string or UUID object) 

33 user_id: User UUID for audit logging (optional) 

34 """ 

35 self.tenant_id = str(tenant_id) 

36 self.user_id = str(user_id) if user_id else None 

37 

38 def to_dict(self) -> dict: 

39 """ 

40 Convert to dictionary for storage 

41 

42 Returns: 

43 Dictionary with tenant_id and user_id 

44 """ 

45 return {"tenant_id": self.tenant_id, "user_id": self.user_id} 

46 

47 

async def set_tenant_context(
    session: AsyncSession, tenant_id: str | UUID, user_id: Optional[str | UUID] = None
) -> None:
    """
    Set PostgreSQL session variables for Row-Level Security (RLS)

    Extracted from: Intirkast test_rls_isolation.py (set_rls_context)

    Sets:
        - app.current_tenant_id: Used by RLS policies to filter queries
        - app.current_user_id: Used for audit logging (optional, only when
          ``user_id`` is truthy)

    Args:
        session: SQLAlchemy AsyncSession
        tenant_id: Tenant UUID to scope queries to
        user_id: User UUID for audit logging (optional)

    Raises:
        MissingTenantContextError: if ``tenant_id`` is empty/falsy.

    Example RLS Policy:
        CREATE POLICY tenant_isolation_policy ON users
        FOR ALL
        USING (tenant_id = NULLIF(current_setting('app.current_tenant_id', true), '')::UUID);

    Usage:
        async with AsyncSessionLocal() as session:
            # Set tenant context
            await set_tenant_context(session, tenant_id="550e8400-...", user_id="660e8400-...")

            # All queries now automatically filtered by tenant_id
            result = await session.execute(select(User))
            users = result.scalars().all()  # Only returns users from specified tenant

    PostgreSQL Session Variables:
        - set_config(name, value, true): the third argument (is_local = true)
          makes the value last for the current transaction only, exactly like
          SET LOCAL. It relies on an open transaction (SQLAlchemy sessions
          autobegin one on first execute by default).
        - current_setting('app.current_tenant_id', true): Retrieves variable
          (true = return NULL instead of raising if it was never defined)
        - NULLIF(..., ''): Converts empty string to NULL (handles missing variable)
        - ::UUID: Casts string to UUID type

    Why set_config instead of ``SET LOCAL ... = :param``:
        SET is a utility statement and cannot take bind parameters. Drivers
        that bind server-side (e.g. asyncpg, the usual AsyncSession driver)
        reject ``SET LOCAL x = $1`` with a syntax error. set_config is an
        ordinary function call, so the value stays a properly bound parameter
        (no string interpolation of tenant input into SQL).
    """
    if not tenant_id:
        raise MissingTenantContextError("tenant_id is required for RLS enforcement")

    # Transaction-local tenant scope (equivalent to SET LOCAL, but parameterizable).
    await session.execute(
        text("SELECT set_config('app.current_tenant_id', :tenant_id, true)"),
        {"tenant_id": str(tenant_id)},
    )
    logger.debug("RLS enabled for tenant: %s", tenant_id)

    # Optional audit identity, same transaction-local scope as the tenant.
    if user_id:
        await session.execute(
            text("SELECT set_config('app.current_user_id', :user_id, true)"),
            {"user_id": str(user_id)},
        )
        logger.debug("Audit context set for user: %s", user_id)

101 

102 

async def clear_tenant_context(session: AsyncSession) -> None:
    """
    Clear PostgreSQL RLS session variables

    Extracted from: Intirkast test_rls_isolation.py (clear_rls_context)

    Issues ``RESET`` for each variable, restoring the value it would have had
    if no SET had been issued in this session:
        - app.current_tenant_id
        - app.current_user_id

    Note:
        Clearing the context does NOT turn RLS off. With a policy such as
        ``USING (tenant_id = NULLIF(current_setting('app.current_tenant_id', true), '')::UUID)``
        an unset tenant makes the comparison NULL, so a role that is subject
        to RLS sees *no* rows afterwards. Reading every tenant's data (e.g.
        for admin operations) requires a role that bypasses RLS (superuser,
        BYPASSRLS, or the table owner when FORCE ROW LEVEL SECURITY is not
        enabled) or a policy that explicitly permits it.

    Usage:
        async with AsyncSessionLocal() as session:
            # Set tenant context
            await set_tenant_context(session, tenant_id="550e8400-...")

            # Query tenant-scoped data
            result = await session.execute(select(User))
            users = result.scalars().all()

            # Clear context
            await clear_tenant_context(session)

            # Subsequent queries run with no tenant set; see Note above for
            # what an RLS-subject role will (not) see.
    """
    # RESET takes no parameters, so it is safe on server-side-binding drivers.
    await session.execute(text("RESET app.current_tenant_id"))
    await session.execute(text("RESET app.current_user_id"))
    logger.debug("Tenant context cleared")

132 

133 

async def get_current_tenant_id(session: AsyncSession) -> Optional[str]:
    """
    Read the tenant id currently stored in ``app.current_tenant_id``.

    Returns:
        The tenant id string, or ``None`` when the setting is missing
        (NULL from ``current_setting(..., true)``) or empty.

    Usage:
        async with AsyncSessionLocal() as session:
            tenant_id = await get_current_tenant_id(session)
            if tenant_id:
                print(f"Current tenant: {tenant_id}")
    """
    query = text("SELECT current_setting('app.current_tenant_id', true)")
    value = (await session.execute(query)).scalar()
    # NULL (never defined) and "" (defined but cleared) both collapse to None.
    return value or None

152 

153 

async def get_current_user_id(session: AsyncSession) -> Optional[str]:
    """
    Read the audit user id currently stored in ``app.current_user_id``.

    Returns:
        The user id string, or ``None`` when the setting is missing
        (NULL from ``current_setting(..., true)``) or empty.

    Usage:
        async with AsyncSessionLocal() as session:
            user_id = await get_current_user_id(session)
            if user_id:
                print(f"Current user: {user_id}")
    """
    query = text("SELECT current_setting('app.current_user_id', true)")
    value = (await session.execute(query)).scalar()
    # NULL (never defined) and "" (defined but cleared) both collapse to None.
    return value or None

172 

173 

def get_db_with_rls(
    tenant_id_getter: Callable[..., Any],
    user_id_getter: Optional[Callable[..., Any]] = None,
):
    """
    FastAPI dependency factory to get database session with RLS enabled

    Extracted from: Intirkast app/core/database.py (get_db_with_rls)

    This is a TEMPLATE: the returned dependency raises NotImplementedError
    until it is replaced with an implementation bound to your session
    factory (see "Reference implementation" below).

    Args:
        tenant_id_getter: Function to extract tenant_id from request
        user_id_getter: Function to extract user_id from request (optional)

    Returns:
        FastAPI dependency function (currently a placeholder coroutine
        function that raises NotImplementedError when awaited)

    Usage:
        # Define getter functions
        def get_tenant_id_from_request(request: Request) -> str:
            return request.state.tenant_id

        def get_user_id_from_request(request: Request) -> str:
            return request.state.user_id

        # Create dependency
        get_db_scoped = get_db_with_rls(
            tenant_id_getter=get_tenant_id_from_request,
            user_id_getter=get_user_id_from_request
        )

        # Use in route
        @router.get("/api/users")
        async def list_users(db: AsyncSession = Depends(get_db_scoped)):
            # All queries automatically scoped to tenant
            result = await db.execute(select(User))
            return result.scalars().all()

    PLACEHOLDER Pattern:
        Replace {{AsyncSessionLocal}} with your session factory:

        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        engine = create_async_engine("{{DATABASE_URL}}")
        AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession)

    Reference implementation:
        The getters must be wrapped in ``Depends(...)``; FastAPI does not
        call a plain function used as a parameter default.

        from fastapi import Depends

        async def dependency(
            tenant_id: str = Depends(tenant_id_getter),
            user_id: Optional[str] = (
                Depends(user_id_getter) if user_id_getter else None
            ),
        ):
            async with AsyncSessionLocal() as session:  # closes on exit
                try:
                    await set_tenant_context(session, tenant_id, user_id)
                    yield session
                    await session.commit()
                except Exception:
                    await session.rollback()
                    logger.exception("Database error")
                    raise
    """

    # NOTE(review): these defaults are the raw getter functions, not
    # Depends(...) markers, so FastAPI would not invoke them. Harmless while
    # this placeholder raises unconditionally; wire them with Depends when
    # implementing (see the reference implementation in the docstring).
    async def dependency(
        tenant_id: str = tenant_id_getter, user_id: Optional[str] = user_id_getter or None
    ):
        """
        Database session dependency with RLS context (placeholder).

        Args:
            tenant_id: Tenant ID from request
            user_id: User ID from request (optional)

        Raises:
            NotImplementedError: always, until replaced with a real
                implementation bound to your session factory.
        """
        raise NotImplementedError(
            "Replace {{AsyncSessionLocal}} placeholder with your session factory. "
            "See netrun_rbac.tenant.get_db_with_rls docstring for example."
        )

    return dependency