Coverage for dao/base_dao.py: 36.52%

150 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2026-04-15 12:24 +0000

1""" 

2Base Data Access Object for MissingTable. 

3 

4Provides common functionality and connection management for all DAOs. 

5All domain-specific DAOs should inherit from BaseDAO. 

6 

7## Caching 

8 

9Two decorators provide clean caching for DAO methods: 

10 

11### @dao_cache - Cache read operations 

12 

13 @dao_cache("teams:all") 

14 def get_all_teams(self): 

15 return self.client.table("teams").select("*").execute().data 

16 

17 @dao_cache("teams:club:{club_id}") 

18 def get_club_teams(self, club_id: int): 

19 return self.client.table("teams").eq("club_id", club_id).execute().data 

20 

21The decorator: 

221. Builds cache key from pattern (substituting {arg_name} with actual values) 

232. Checks Redis - returns cached data if found 

243. On miss - runs method, caches result, returns it 

25 

26### @invalidates_cache - Clear cache after writes 

27 

28 @invalidates_cache("mt:dao:teams:*") 

29 def add_team(self, name: str, city: str): 

30 return self.client.table("teams").insert({...}).execute() 

31 

32After successful completion, clears all keys matching the pattern(s). 

33""" 

34 

35import functools 

36import inspect 

37import json 

38import os 

39from typing import TYPE_CHECKING 

40 

41import structlog 

42 

43if TYPE_CHECKING: 

44 from supabase import Client 

45 

# Structured (structlog) logger used by every helper and DAO in this module.
logger = structlog.get_logger()

# Shared Redis client for all DAOs.
# Starts as None; get_redis_client() stores the client it creates here and
# returns this stored instance on later calls instead of creating a new one.
_redis_client = None

50 

51 

def get_redis_client():
    """Get sync Redis client for DAO-level caching.

    Returns the shared Redis client instance, or None if caching is disabled
    or Redis is unavailable. Gracefully degrades - the app continues without
    cache.

    The client is stored in the module-level ``_redis_client`` only after a
    successful PING, so a failed connection is never cached: the next call
    simply tries to connect again. Once a client has been stored, it is
    returned directly without re-reading CACHE_ENABLED.

    Environment:
        CACHE_ENABLED: must be "true" (case-insensitive) to enable caching.
        REDIS_URL: connection URL (default "redis://localhost:6379/0").

    Returns:
        A redis client created with ``decode_responses=True``, or None.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    if os.getenv("CACHE_ENABLED", "false").lower() != "true":
        return None

    try:
        # Imported lazily so the redis package is only needed when caching is on.
        import redis

        url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        client = redis.from_url(url, decode_responses=True)
        # Verify connectivity BEFORE publishing the client; assigning the global
        # first would leave an unreachable client cached for all later calls.
        client.ping()
    except Exception as e:
        logger.warning("dao_redis_connection_failed", error=str(e))
        return None

    _redis_client = client
    return _redis_client

75 

76 

def clear_cache(pattern: str) -> int:
    """Delete every Redis key matching a glob-style pattern.

    The keyspace is walked incrementally with SCAN (COUNT hint of 100), and
    each returned batch is deleted immediately. Any Redis error is logged and
    reported as 0 deleted keys rather than raised.

    Args:
        pattern: Redis key pattern (e.g., "mt:dao:clubs:*")

    Returns:
        Number of keys deleted; 0 when caching is unavailable or on error.
    """
    client = get_redis_client()
    if not client:
        return 0

    try:
        total = 0
        cursor = None
        # A SCAN iteration is complete once the server hands back cursor 0.
        while cursor != 0:
            cursor, batch = client.scan(cursor or 0, match=pattern, count=100)
            if batch:
                total += client.delete(*batch)
        if total > 0:
            logger.info("dao_cache_cleared", pattern=pattern, deleted=total)
        return total
    except Exception as e:
        logger.warning("dao_cache_clear_error", pattern=pattern, error=str(e))
        return 0

104 

105 

def cache_get(key: str):
    """Look up ``key`` in Redis and JSON-decode the stored value.

    Args:
        key: Cache key

    Returns:
        The decoded value. None when caching is unavailable, the key is absent
        or holds an empty string, the stored JSON is ``null``, or the
        read/decode fails (the error is logged).
    """
    client = get_redis_client()
    if not client:
        return None

    try:
        raw = client.get(key)
        if not raw:
            return None
        logger.info("dao_cache_hit", key=key)
        return json.loads(raw)
    except Exception as e:
        logger.warning("dao_cache_get_error", key=key, error=str(e))
        return None

126 

127 

def cache_set(key: str, value, ttl: int = 86400) -> bool:
    """Store ``value`` under ``key`` as JSON with an expiry.

    Values that ``json`` cannot encode natively are stringified via
    ``default=str``.

    Args:
        key: Cache key
        value: Value to cache (will be JSON serialized)
        ttl: Time to live in seconds (default 24 hours)

    Returns:
        True if the value was written, False if caching is unavailable or
        the write failed (the error is logged).
    """
    client = get_redis_client()
    if not client:
        return False

    try:
        payload = json.dumps(value, default=str)
        client.setex(key, ttl, payload)
        logger.info("dao_cache_set", key=key)
    except Exception as e:
        logger.warning("dao_cache_set_error", key=key, error=str(e))
        return False
    return True

149 

150 

151# ============================================================================= 

152# CACHING DECORATORS 

153# ============================================================================= 

154 

155 

def dao_cache(key_pattern: str, ttl: int = 86400):
    """Decorator to cache DAO method results in Redis.

    Args:
        key_pattern: Cache key pattern with optional {arg} placeholders,
            e.g. "teams:all" or "teams:club:{club_id}". The final key is
            "mt:dao:" + the formatted pattern. Placeholders may name any
            parameter of the wrapped method, including ones the caller left
            at their default value, and keys passed through ``**kwargs``.
        ttl: Time to live in seconds (default 24 hours)

    Example:
        @dao_cache("teams:club:{club_id}")
        def get_club_teams(self, club_id: int):
            return self.client.table("teams").select("*").eq("club_id", club_id).execute().data

    How it works:
        1. When get_club_teams(club_id=5) is called
        2. Decorator builds key: "mt:dao:teams:club:5"
        3. Checks Redis for cached data
        4. If found: returns cached data (method never runs)
        5. If not found: runs method, caches result (unless None), returns it

    Notes:
        - Cached values round-trip through JSON (cache_set serializes with
          ``json.dumps(..., default=str)``), so a cache hit returns plain JSON
          types: tuples come back as lists, and non-JSON objects such as
          datetimes come back as strings. A cache miss returns the method's
          original object.
        - A None result is never cached, so such calls always run the method.
        - If the pattern cannot be filled from the call's arguments, the method
          runs uncached and a "dao_cache_key_error" warning is logged.
    """

    def decorator(func):
        # Inspect the signature once, at decoration time, not on every call.
        sig = inspect.signature(func)
        # The first parameter is the instance ("self"); it never feeds the key.
        instance_param = next(iter(sig.parameters), None)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Bind the call to the signature so every parameter, including
            # ones left at their default value, is available to the pattern.
            try:
                bound = sig.bind(self, *args, **kwargs)
            except TypeError:
                # Invalid call: let the real method raise its own TypeError.
                return func(self, *args, **kwargs)
            bound.apply_defaults()

            key_values = {}
            for name, value in bound.arguments.items():
                if name == instance_param:
                    continue
                if sig.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
                    # Flatten **kwargs so their keys can be used as placeholders.
                    key_values.update(value)
                else:
                    key_values[name] = value

            # Build full cache key: mt:dao:{pattern with substitutions}
            try:
                cache_key = f"mt:dao:{key_pattern.format(**key_values)}"
            except (KeyError, IndexError) as e:
                # Pattern needs a value this call doesn't supply: skip caching.
                logger.warning("dao_cache_key_error", pattern=key_pattern, missing=str(e))
                return func(self, *args, **kwargs)

            # Try to get from cache
            cached = cache_get(cache_key)
            if cached is not None:
                return cached

            # Cache miss - run the actual method
            logger.debug("dao_cache_miss", key=cache_key, method=func.__name__)
            result = func(self, *args, **kwargs)

            # None is not cached: cache_get uses None to signal a miss.
            if result is not None:
                cache_set(cache_key, result, ttl)

            return result

        return wrapper

    return decorator

215 

216 

def invalidates_cache(*patterns: str):
    """Decorator that flushes matching cache keys after a write method returns.

    Args:
        *patterns: One or more Redis glob patterns, used verbatim (so they
            must include the "mt:dao:" prefix), e.g. "mt:dao:teams:*",
            "mt:dao:clubs:*"

    Example:
        @invalidates_cache("mt:dao:teams:*")
        def add_team(self, name: str, city: str):
            return self.client.table("teams").insert({...}).execute()

        @invalidates_cache("mt:dao:teams:*", "mt:dao:clubs:*")
        def update_team_club(self, team_id: int, club_id: int):
            ...

    How it works:
        1. Calls the wrapped method
        2. If it raised, the exception propagates and nothing is cleared
        3. Otherwise clears each pattern, in order, via clear_cache()
        4. Returns the method's result unchanged
    """

    def decorator(func):
        def flush():
            # clear_cache's return value (deleted-key count) is not needed here.
            for glob in patterns:
                clear_cache(glob)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            outcome = func(self, *args, **kwargs)
            flush()
            return outcome

        return wrapper

    return decorator

254 

255 

256class BaseDAO: 

257 """Base DAO with shared connection logic and common utilities.""" 

258 

259 def __init__(self, connection_holder): 

260 """ 

261 Initialize with a SupabaseConnection. 

262 

263 Args: 

264 connection_holder: SupabaseConnection instance 

265 

266 Raises: 

267 TypeError: If connection_holder is not a SupabaseConnection 

268 """ 

269 from dao.match_dao import SupabaseConnection 

270 

271 if not isinstance(connection_holder, SupabaseConnection): 271 ↛ 272line 271 didn't jump to line 272 because the condition on line 271 was never true

272 raise TypeError("connection_holder must be a SupabaseConnection instance") 

273 

274 self.connection_holder = connection_holder 

275 self.client: Client = connection_holder.get_client() 

276 

277 def execute_query(self, query, operation_name: str = "database operation"): 

278 """ 

279 Execute a Supabase query with common error handling. 

280 

281 Args: 

282 query: Supabase query object (result of .select(), .insert(), etc.) 

283 operation_name: Description of the operation for logging 

284 

285 Returns: 

286 Query response data 

287 

288 Raises: 

289 Exception: Re-raises any database errors after logging 

290 """ 

291 try: 

292 response = query.execute() 

293 return response 

294 except Exception as e: 

295 logger.exception( 

296 f"Error during {operation_name}", 

297 operation=operation_name, 

298 error_type=type(e).__name__, 

299 error_message=str(e), 

300 ) 

301 raise 

302 

303 def safe_execute(self, query, operation_name: str = "database operation", default=None): 

304 """ 

305 Execute a query with error handling that returns a default value on error. 

306 

307 Useful for non-critical queries where you want to continue execution 

308 even if the query fails. 

309 

310 Args: 

311 query: Supabase query object 

312 operation_name: Description of the operation for logging 

313 default: Value to return if query fails (default: None) 

314 

315 Returns: 

316 Query response data or default value on error 

317 """ 

318 try: 

319 response = query.execute() 

320 return response 

321 except Exception as e: 

322 logger.warning( 

323 f"Non-critical error during {operation_name}", 

324 operation=operation_name, 

325 error_type=type(e).__name__, 

326 error_message=str(e), 

327 returning_default=default, 

328 ) 

329 return default 

330 

331 def get_by_id(self, table: str, record_id: int, id_field: str = "id") -> dict | None: 

332 """ 

333 Generic method to get a single record by ID. 

334 

335 Args: 

336 table: Table name 

337 record_id: Record ID to fetch 

338 id_field: Name of the ID field (default: "id") 

339 

340 Returns: 

341 dict: Record data or None if not found 

342 

343 Raises: 

344 Exception: If database query fails 

345 """ 

346 try: 

347 response = self.client.table(table).select("*").eq(id_field, record_id).execute() 

348 return response.data[0] if response.data else None 

349 except Exception: 

350 logger.exception( 

351 f"Error fetching record from {table}", 

352 table=table, 

353 record_id=record_id, 

354 id_field=id_field, 

355 ) 

356 raise 

357 

358 def get_all(self, table: str, order_by: str | None = None) -> list[dict]: 

359 """ 

360 Generic method to get all records from a table. 

361 

362 Args: 

363 table: Table name 

364 order_by: Optional field name to order by 

365 

366 Returns: 

367 list[dict]: List of records 

368 

369 Raises: 

370 Exception: If database query fails 

371 """ 

372 try: 

373 query = self.client.table(table).select("*") 

374 if order_by: 

375 query = query.order(order_by) 

376 response = query.execute() 

377 return response.data 

378 except Exception: 

379 logger.exception(f"Error fetching all records from {table}", table=table, order_by=order_by) 

380 raise 

381 

382 def exists(self, table: str, field: str, value) -> bool: 

383 """ 

384 Check if a record exists with the given field value. 

385 

386 Args: 

387 table: Table name 

388 field: Field name to check 

389 value: Value to match 

390 

391 Returns: 

392 bool: True if record exists, False otherwise 

393 """ 

394 try: 

395 response = self.client.table(table).select("id").eq(field, value).limit(1).execute() 

396 return len(response.data) > 0 

397 except Exception: 

398 logger.exception(f"Error checking existence in {table}", table=table, field=field, value=value) 

399 return False 

400 

401 def delete_by_id(self, table: str, record_id: int, id_field: str = "id") -> bool: 

402 """ 

403 Generic method to delete a record by ID. 

404 

405 Args: 

406 table: Table name 

407 record_id: Record ID to delete 

408 id_field: Name of the ID field (default: "id") 

409 

410 Returns: 

411 bool: True if successful, False otherwise 

412 """ 

413 try: 

414 self.client.table(table).delete().eq(id_field, record_id).execute() 

415 logger.info(f"Deleted record from {table}", table=table, record_id=record_id) 

416 return True 

417 except Exception: 

418 logger.exception(f"Error deleting record from {table}", table=table, record_id=record_id) 

419 return False