Coverage for src / crump / database.py: 91%
642 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 14:40 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 14:40 +0000
1"""Database operations for crump."""
3from __future__ import annotations
5import datetime
6import logging
7import sqlite3
8from pathlib import Path
9from typing import Any, Protocol
11import psycopg
12from psycopg import sql
14from crump.config import ColumnMapping, CrumpJob, FailureMode, apply_row_transformations
15from crump.tabular_file import create_reader
def _detect_file_format(file_path: Path) -> Any:
    """Pick the tabular reader format for a file based on its extension.

    Args:
        file_path: Path to the file.

    Returns:
        ``InputFileType.PARQUET`` when the path is recognised as Parquet,
        otherwise ``InputFileType.CSV``. Unknown extensions and every other
        recognised type (including CDF) fall back to CSV.

    Note:
        Only CSV and Parquet are reported because those are the formats the
        tabular file reader supports; CDF files must be extracted first.
    """
    from crump.file_types import InputFileType

    try:
        detected = InputFileType.from_path(str(file_path))
        is_parquet = detected == InputFileType.PARQUET
    except ValueError:
        # from_path rejects extensions it does not know; treat those as CSV.
        return InputFileType.CSV

    return InputFileType.PARQUET if is_parquet else InputFileType.CSV
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class DryRunSummary:
    """Container for the changes a dry-run sync would make."""

    def __init__(self) -> None:
        """Create an empty summary: blank table name, no schema changes, zero rows."""
        # Target table and whether it is already present.
        self.table_name: str = ""
        self.table_exists: bool = False

        # Schema changes that would be applied.
        self.new_columns: list[tuple[str, str]] = []
        self.new_indexes: list[str] = []

        # Row-level effects.
        self.rows_to_sync: int = 0
        self.rows_to_delete: int = 0
class DatabaseBackend(Protocol):
    """Structural interface every database backend must provide.

    The PostgreSQL and SQLite backends in this module implement these methods;
    ``DatabaseConnection`` talks to whichever backend it created only through
    this interface.
    """

    def execute(self, query: str, params: tuple[Any, ...] | None = None) -> None:
        """Execute a query, optionally with positional parameters; returns nothing."""
        ...

    def fetchall(self, query: str, params: tuple[Any, ...] | None = None) -> list[tuple[Any, ...]]:
        """Execute a query and return all result rows as a list of tuples."""
        ...

    def commit(self) -> None:
        """Commit the current transaction."""
        ...

    def close(self) -> None:
        """Close the connection."""
        ...

    def map_data_type(self, data_type: str | None) -> str:
        """Map a config data type name (or None) to the backend's SQL column type."""
        ...

    def create_table_if_not_exists(
        self, table_name: str, columns: dict[str, str], primary_keys: list[str] | None = None
    ) -> None:
        """Create the table if it doesn't exist.

        Args:
            table_name: Name of the table
            columns: Mapping of column name to SQL type definition
            primary_keys: Optional list of column names forming the primary key
        """
        ...

    def get_existing_columns(self, table_name: str) -> set[str]:
        """Get the set of existing column names in a table."""
        ...

    def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
        """Add a new column with the given SQL type to an existing table."""
        ...

    def upsert_row(
        self, table_name: str, conflict_columns: list[str], row_data: dict[str, Any]
    ) -> None:
        """Insert a row, or update it when it conflicts on ``conflict_columns``.

        Args:
            table_name: Name of the table
            conflict_columns: Columns that identify an existing row
            row_data: Mapping of column name to value for the row
        """
        ...

    def delete_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Delete records that aren't in the current CSV, scoped by a compound filter key.

        Args:
            table_name: Name of the table
            id_columns: List of ID column names (for compound keys)
            filter_columns: Dictionary of column_name -> value to filter by
            current_ids: Set of ID tuples from the current CSV

        Returns:
            Count of records deleted
        """
        ...

    def count_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Count records that would be deleted using the compound filter key.

        Takes the same arguments as ``delete_stale_records_compound``.

        Returns:
            Count of records that would be deleted
        """
        ...

    def get_existing_indexes(self, table_name: str) -> set[str]:
        """Get the set of existing index names for a table."""
        ...

    def create_index(
        self, table_name: str, index_name: str, columns: list[tuple[str, str]]
    ) -> None:
        """Create an index on the specified columns.

        Args:
            table_name: Name of the table
            index_name: Name of the index to create
            columns: List of (column_name, order) tuples, e.g. [('email', 'ASC'), ('date', 'DESC')]
        """
        ...

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        ...
class PostgreSQLBackend:
    """PostgreSQL database backend built on a psycopg connection.

    Table, column and index names are composed through ``psycopg.sql`` so they
    are quoted as identifiers; row values are passed as query parameters.
    """

    # Config data type -> PostgreSQL column type. VARCHAR(N) is handled
    # separately in map_data_type because it carries a length.
    _TYPE_MAPPING = {
        "integer": "INTEGER",
        "int": "INTEGER",
        "bigint": "BIGINT",
        "float": "DOUBLE PRECISION",
        "double": "DOUBLE PRECISION",
        "date": "DATE",
        "datetime": "TIMESTAMP",
        "timestamp": "TIMESTAMP",
        "text": "TEXT",
        "string": "TEXT",
    }

    def __init__(self, connection_string: str) -> None:
        """Open a PostgreSQL connection for ``connection_string``."""
        self.conn = psycopg.connect(connection_string)

    def execute(self, query: str, params: tuple[Any, ...] | None = None) -> None:
        """Execute a statement and discard any result.

        Empty ``params`` are treated like ``None`` (no parameters passed).
        """
        with self.conn.cursor() as cur:
            cur.execute(query, params or None)

    def fetchall(self, query: str, params: tuple[Any, ...] | None = None) -> list[tuple[Any, ...]]:
        """Execute a query and return all result rows."""
        with self.conn.cursor() as cur:
            cur.execute(query, params or None)
            return cur.fetchall()

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()

    def close(self) -> None:
        """Close the connection."""
        self.conn.close()

    def map_data_type(self, data_type: str | None) -> str:
        """Map a config data type to a PostgreSQL type.

        Args:
            data_type: Config type name (case-insensitive, surrounding
                whitespace ignored), or None.

        Returns:
            The PostgreSQL type. ``None`` and unrecognised names map to TEXT;
            ``varchar(N)`` is passed through upper-cased so the length is kept.
        """
        if data_type is None:
            return "TEXT"

        stripped = data_type.strip()
        normalized = stripped.lower()

        if normalized.startswith("varchar"):
            # Return the stripped form so stray whitespace from the config
            # does not leak into the generated DDL.
            return stripped.upper()

        return self._TYPE_MAPPING.get(normalized, "TEXT")

    def create_table_if_not_exists(
        self, table_name: str, columns: dict[str, str], primary_keys: list[str] | None = None
    ) -> None:
        """Create the table if it doesn't exist, then commit.

        Args:
            table_name: Name of the table
            columns: Mapping of column name to SQL type definition
            primary_keys: Optional column names for a PRIMARY KEY constraint
        """
        column_defs = [
            sql.SQL("{} {}").format(sql.Identifier(col_name), sql.SQL(col_type))
            for col_name, col_type in columns.items()
        ]

        if primary_keys:
            column_defs.append(
                sql.SQL("PRIMARY KEY ({})").format(
                    sql.SQL(", ").join(sql.Identifier(pk) for pk in primary_keys)
                )
            )

        query = sql.SQL("CREATE TABLE IF NOT EXISTS {} ({})").format(
            sql.Identifier(table_name), sql.SQL(", ").join(column_defs)
        )
        self.execute(query.as_string(self.conn))
        self.commit()

    def get_existing_columns(self, table_name: str) -> set[str]:
        """Get the lower-cased set of existing column names in a table.

        Uses case-insensitive comparison to handle quoted identifiers that preserve case.
        """
        query = """
            SELECT column_name
            FROM information_schema.columns
            WHERE LOWER(table_name) = LOWER(%s)
        """
        results = self.fetchall(query, (table_name,))
        return {row[0].lower() for row in results}

    def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
        """Add a new column to an existing table, then commit."""
        query = sql.SQL("ALTER TABLE {} ADD COLUMN {} {}").format(
            sql.Identifier(table_name),
            sql.Identifier(column_name),
            sql.SQL(column_type),
        )
        self.execute(query.as_string(self.conn))
        self.commit()

    def upsert_row(
        self, table_name: str, conflict_columns: list[str], row_data: dict[str, Any]
    ) -> None:
        """Insert a row, updating the non-key columns on conflict, then commit.

        Args:
            table_name: Name of the table
            conflict_columns: Columns that identify an existing row
            row_data: Mapping of column name to value
        """
        columns = list(row_data.keys())
        values = tuple(row_data.values())
        update_columns = [col for col in columns if col not in conflict_columns]

        if update_columns:
            conflict_action = sql.SQL("DO UPDATE SET {}").format(
                sql.SQL(", ").join(
                    sql.SQL("{} = EXCLUDED.{}").format(sql.Identifier(col), sql.Identifier(col))
                    for col in update_columns
                )
            )
        else:
            # Every column is part of the key: an empty SET list would be a
            # syntax error, and there is nothing to update anyway.
            conflict_action = sql.SQL("DO NOTHING")

        insert_query = sql.SQL("INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) {}").format(
            sql.Identifier(table_name),
            sql.SQL(", ").join(sql.Identifier(col) for col in columns),
            sql.SQL(", ").join(sql.Placeholder() * len(values)),
            sql.SQL(", ").join(sql.Identifier(col) for col in conflict_columns),
            conflict_action,
        )
        self.execute(insert_query.as_string(self.conn), values)
        self.commit()

    def _stale_records_where(
        self,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> tuple[Any, tuple[Any, ...]]:
        """Build the shared WHERE condition and parameters for stale-record queries.

        The condition has the shape
        ``col1 = %s AND col2 = %s AND <ids> NOT IN (...)`` where ``<ids>`` is a
        single column or a row value for compound keys.

        Returns:
            (composed WHERE condition, parameters in placeholder order)
        """
        filter_conditions = [
            sql.SQL("{} = %s").format(sql.Identifier(col)) for col in filter_columns
        ]

        if len(id_columns) == 1:
            # Single key: compare one column against a flat list of values.
            id_values = [
                id_val[0] if isinstance(id_val, tuple) else id_val for id_val in current_ids
            ]
            id_expr = sql.Identifier(id_columns[0])
            id_placeholders = sql.SQL(", ").join(sql.Placeholder() * len(id_values))
        else:
            # Compound key: compare a row value against a list of row values.
            id_expr = sql.SQL("({})").format(
                sql.SQL(", ").join(sql.Identifier(col) for col in id_columns)
            )
            id_placeholders = sql.SQL(", ").join(
                sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() * len(id_columns)))
                for _ in current_ids
            )
            # Flatten the ID tuples in the same iteration order as the placeholders.
            id_values = [val for id_tuple in current_ids for val in id_tuple]

        where = sql.SQL("{} AND {} NOT IN ({})").format(
            sql.SQL(" AND ").join(filter_conditions), id_expr, id_placeholders
        )
        params = tuple(list(filter_columns.values()) + id_values)
        return where, params

    def count_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Count records that would be deleted using compound filter key.

        Args:
            table_name: Name of the table
            id_columns: List of ID column names (for compound keys)
            filter_columns: Dictionary of column_name -> value to filter by (compound key)
            current_ids: Set of ID tuples from the current CSV

        Returns:
            Count of records that would be deleted; 0 when ``current_ids`` or
            ``filter_columns`` is empty.
        """
        if not current_ids or not filter_columns:
            return 0

        where, params = self._stale_records_where(id_columns, filter_columns, current_ids)
        count_query = sql.SQL("SELECT COUNT(*) FROM {} WHERE {}").format(
            sql.Identifier(table_name), where
        )
        count_result = self.fetchall(count_query.as_string(self.conn), params)
        return count_result[0][0] if count_result else 0

    def delete_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Delete records from database that aren't in current CSV using compound filter key.

        Args:
            table_name: Name of the table
            id_columns: List of ID column names (for compound keys)
            filter_columns: Dictionary of column_name -> value to filter by (compound key)
            current_ids: Set of ID tuples from the current CSV

        Returns:
            Count of records deleted (taken from a COUNT run just before the
            DELETE); 0 when ``current_ids`` or ``filter_columns`` is empty.
        """
        if not current_ids or not filter_columns:
            return 0

        where, params = self._stale_records_where(id_columns, filter_columns, current_ids)
        table = sql.Identifier(table_name)

        # Count first
        count_sql = sql.SQL("SELECT COUNT(*) FROM {} WHERE {}").format(table, where).as_string(
            self.conn
        )
        logger.debug("PostgreSQL count query: %s", count_sql)
        logger.debug("PostgreSQL count params: %s", params)
        count_result = self.fetchall(count_sql, params)
        deleted_count = count_result[0][0] if count_result else 0

        # Then delete
        delete_sql = sql.SQL("DELETE FROM {} WHERE {}").format(table, where).as_string(self.conn)
        logger.debug("PostgreSQL delete query: %s", delete_sql)
        logger.debug("PostgreSQL delete params: %s", params)
        logger.debug("PostgreSQL deleted count: %s", deleted_count)
        self.execute(delete_sql, params)
        self.commit()

        return deleted_count

    def get_existing_indexes(self, table_name: str) -> set[str]:
        """Get the lower-cased set of existing index names for a table.

        Uses case-insensitive comparison to handle quoted identifiers that preserve case.
        """
        query = """
            SELECT indexname
            FROM pg_indexes
            WHERE LOWER(tablename) = LOWER(%s)
        """
        results = self.fetchall(query, (table_name,))
        return {row[0].lower() for row in results}

    def create_index(
        self, table_name: str, index_name: str, columns: list[tuple[str, str]]
    ) -> None:
        """Create an index on the specified columns if it doesn't exist, then commit.

        Args:
            table_name: Name of the table
            index_name: Name of the index to create
            columns: List of (column_name, order) tuples, e.g. [('email', 'ASC')]
        """
        column_parts = [
            sql.SQL("{} {}").format(sql.Identifier(col_name), sql.SQL(order))
            for col_name, order in columns
        ]

        query = sql.SQL("CREATE INDEX IF NOT EXISTS {} ON {} ({})").format(
            sql.Identifier(index_name),
            sql.Identifier(table_name),
            sql.SQL(", ").join(column_parts),
        )

        self.execute(query.as_string(self.conn))
        self.commit()

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Uses case-insensitive comparison to handle quoted identifiers that preserve case.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        query = """
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE LOWER(table_name) = LOWER(%s)
            )
        """
        result = self.fetchall(query, (table_name,))
        return result[0][0] if result else False
class SQLiteBackend:
    """SQLite database backend built on the standard-library ``sqlite3`` module.

    Identifiers are always emitted through ``_quote_identifier`` so names that
    contain double quotes (e.g. unusual CSV headers) cannot break the SQL.
    """

    # URL prefixes stripped from the connection string, longest first.
    _URL_PREFIXES = ("sqlite:///", "sqlite://")

    # Config data type -> SQLite storage type. VARCHAR is handled separately
    # in map_data_type.
    _TYPE_MAPPING = {
        "integer": "INTEGER",
        "int": "INTEGER",
        "bigint": "INTEGER",  # SQLite INTEGER is 8-byte signed, equivalent to BIGINT
        "float": "REAL",
        "double": "REAL",
        "date": "TEXT",
        "datetime": "TEXT",
        "timestamp": "TEXT",
        "text": "TEXT",
        "string": "TEXT",
    }

    def __init__(self, connection_string: str) -> None:
        """Open a SQLite connection.

        Args:
            connection_string: ``sqlite:///path/to/db.db``, ``sqlite:///:memory:``,
                ``sqlite://...`` or a bare path; any URL prefix is stripped and
                the remainder is handed to ``sqlite3.connect``.
        """
        for prefix in self._URL_PREFIXES:
            if connection_string.startswith(prefix):
                db_path = connection_string[len(prefix):]
                break
        else:
            db_path = connection_string

        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier, doubling embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    def execute(self, query: str, params: tuple[Any, ...] | None = None) -> None:
        """Execute a statement on the shared cursor; empty/None params mean no parameters."""
        self.cursor.execute(query, params or ())

    def fetchall(self, query: str, params: tuple[Any, ...] | None = None) -> list[tuple[Any, ...]]:
        """Execute a query on the shared cursor and return all result rows."""
        self.cursor.execute(query, params or ())
        return self.cursor.fetchall()

    def commit(self) -> None:
        """Commit the current transaction."""
        self.conn.commit()

    def close(self) -> None:
        """Close the cursor and the connection."""
        self.cursor.close()
        self.conn.close()

    def map_data_type(self, data_type: str | None) -> str:
        """Map a config data type to a SQLite type.

        Args:
            data_type: Config type name (case-insensitive, surrounding
                whitespace ignored), or None.

        Returns:
            The SQLite type. ``None``, ``varchar(N)`` and unrecognised names map to TEXT.
        """
        if data_type is None:
            return "TEXT"

        normalized = data_type.lower().strip()

        # SQLite doesn't have VARCHAR, use TEXT
        if normalized.startswith("varchar"):
            return "TEXT"

        return self._TYPE_MAPPING.get(normalized, "TEXT")

    def create_table_if_not_exists(
        self, table_name: str, columns: dict[str, str], primary_keys: list[str] | None = None
    ) -> None:
        """Create the table if it doesn't exist, then commit.

        Args:
            table_name: Name of the table
            columns: Mapping of column name to SQL type definition
            primary_keys: Optional column names for a PRIMARY KEY constraint
        """
        quote = self._quote_identifier
        definitions = [f"{quote(col_name)} {col_type}" for col_name, col_type in columns.items()]

        if primary_keys:
            pk_columns = ", ".join(quote(pk) for pk in primary_keys)
            definitions.append(f"PRIMARY KEY ({pk_columns})")

        query = f"CREATE TABLE IF NOT EXISTS {quote(table_name)} ({', '.join(definitions)})"
        self.execute(query)
        self.commit()

    def get_existing_columns(self, table_name: str) -> set[str]:
        """Get the lower-cased set of existing column names in a table."""
        results = self.fetchall(f"PRAGMA table_info({self._quote_identifier(table_name)})")
        # PRAGMA table_info returns: (cid, name, type, notnull, dflt_value, pk)
        return {row[1].lower() for row in results}

    def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
        """Add a new column to an existing table, then commit."""
        quote = self._quote_identifier
        query = f"ALTER TABLE {quote(table_name)} ADD COLUMN {quote(column_name)} {column_type}"
        self.execute(query)
        self.commit()

    def upsert_row(
        self, table_name: str, conflict_columns: list[str], row_data: dict[str, Any]
    ) -> None:
        """Insert a row, updating the non-key columns on conflict, then commit.

        Args:
            table_name: Name of the table
            conflict_columns: Columns that identify an existing row
            row_data: Mapping of column name to value
        """
        quote = self._quote_identifier
        columns = list(row_data.keys())
        values = tuple(row_data.values())

        columns_str = ", ".join(quote(col) for col in columns)
        placeholders = ", ".join("?" * len(values))
        conflict_cols_str = ", ".join(quote(col) for col in conflict_columns)
        assignments = ", ".join(
            f"{quote(col)} = excluded.{quote(col)}"
            for col in columns
            if col not in conflict_columns
        )
        # When every column is part of the key there is nothing to update, and
        # an empty SET list would be a syntax error.
        conflict_action = f"DO UPDATE SET {assignments}" if assignments else "DO NOTHING"

        query = (
            f"INSERT INTO {quote(table_name)} ({columns_str}) VALUES ({placeholders}) "
            f"ON CONFLICT ({conflict_cols_str}) {conflict_action}"
        )

        self.execute(query, values)
        self.commit()

    def get_existing_indexes(self, table_name: str) -> set[str]:
        """Get the lower-cased set of existing index names for a table.

        The table name is matched case-insensitively, as in the PostgreSQL backend.
        """
        query = "SELECT name FROM sqlite_master WHERE type='index' AND LOWER(tbl_name) = LOWER(?)"
        results = self.fetchall(query, (table_name,))
        return {row[0].lower() for row in results}

    def create_index(
        self, table_name: str, index_name: str, columns: list[tuple[str, str]]
    ) -> None:
        """Create an index on the specified columns if it doesn't exist, then commit.

        Args:
            table_name: Name of the table
            index_name: Name of the index to create
            columns: List of (column_name, order) tuples, e.g. [('email', 'ASC')]
        """
        quote = self._quote_identifier
        columns_str = ", ".join(f"{quote(col_name)} {order}" for col_name, order in columns)
        query = (
            f"CREATE INDEX IF NOT EXISTS {quote(index_name)} "
            f"ON {quote(table_name)} ({columns_str})"
        )

        self.execute(query)
        self.commit()

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        The name is matched case-insensitively, as in the PostgreSQL backend.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        query = "SELECT name FROM sqlite_master WHERE type='table' AND LOWER(name) = LOWER(?)"
        result = self.fetchall(query, (table_name,))
        return len(result) > 0

    def _stale_records_where(
        self,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> tuple[str, tuple[Any, ...]]:
        """Build the shared WHERE condition and parameters for stale-record queries.

        Returns:
            (WHERE condition without the keyword, parameters in placeholder order)
        """
        quote = self._quote_identifier
        filter_sql = " AND ".join(f"{quote(col)} = ?" for col in filter_columns)

        if len(id_columns) == 1:
            # Single key: compare one column against a flat list of values.
            id_values = [
                id_val[0] if isinstance(id_val, tuple) else id_val for id_val in current_ids
            ]
            id_expr = quote(id_columns[0])
            id_list = ", ".join("?" * len(id_values))
        else:
            # Compound key: SQLite only allows a row value on the left of IN
            # when the right-hand side is a subquery, so the ID tuples are
            # supplied through a VALUES clause.
            id_expr = "(" + ", ".join(quote(col) for col in id_columns) + ")"
            row_placeholder = "(" + ", ".join("?" * len(id_columns)) + ")"
            id_list = "VALUES " + ", ".join(row_placeholder for _ in current_ids)
            # Flatten the ID tuples in the same iteration order as the placeholders.
            id_values = [val for id_tuple in current_ids for val in id_tuple]

        where = f"{filter_sql} AND {id_expr} NOT IN ({id_list})"
        params = tuple(list(filter_columns.values()) + id_values)
        return where, params

    def delete_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Delete records from database that aren't in current CSV using compound filter key.

        Args:
            table_name: Name of the table
            id_columns: List of ID column names (for compound keys)
            filter_columns: Dictionary of column_name -> value to filter by (compound key)
            current_ids: Set of ID tuples from the current CSV

        Returns:
            Count of records deleted (taken from a COUNT run just before the
            DELETE); 0 when ``current_ids`` or ``filter_columns`` is empty.
        """
        if not current_ids or not filter_columns:
            return 0

        where, params = self._stale_records_where(id_columns, filter_columns, current_ids)
        table = self._quote_identifier(table_name)
        count_query = f"SELECT COUNT(*) FROM {table} WHERE {where}"
        delete_query = f"DELETE FROM {table} WHERE {where}"

        # Count first
        logger.debug("SQLite count query: %s", count_query)
        logger.debug("SQLite count params: %s", params)
        count_result = self.fetchall(count_query, params)
        deleted_count = count_result[0][0] if count_result else 0

        # Delete stale records
        logger.debug("SQLite delete query: %s", delete_query)
        logger.debug("SQLite delete params: %s", params)
        logger.debug("SQLite deleted count: %s", deleted_count)
        self.execute(delete_query, params)
        self.commit()

        return deleted_count

    def count_stale_records_compound(
        self,
        table_name: str,
        id_columns: list[str],
        filter_columns: dict[str, str],
        current_ids: set[tuple],
    ) -> int:
        """Count records that would be deleted using compound filter key.

        Args:
            table_name: Name of the table
            id_columns: List of ID column names (for compound keys)
            filter_columns: Dictionary of column_name -> value to filter by (compound key)
            current_ids: Set of ID tuples from the current CSV

        Returns:
            Count of records that would be deleted; 0 when ``current_ids`` or
            ``filter_columns`` is empty.
        """
        if not current_ids or not filter_columns:
            return 0

        where, params = self._stale_records_where(id_columns, filter_columns, current_ids)
        count_query = f"SELECT COUNT(*) FROM {self._quote_identifier(table_name)} WHERE {where}"
        count_result = self.fetchall(count_query, params)
        return count_result[0][0] if count_result else 0
760class DatabaseConnection:
761 """Database connection handler supporting PostgreSQL and SQLite."""
763 def __init__(self, connection_string: str) -> None:
764 """Initialize database connection.
766 Args:
767 connection_string: Database connection string
768 - PostgreSQL: postgresql://user:pass@host:port/db
769 - SQLite: sqlite:///path/to/db.db or sqlite:///:memory:
770 """
771 self.connection_string = connection_string
772 self.backend: DatabaseBackend | None = None
774 def __enter__(self) -> DatabaseConnection:
775 """Enter context manager."""
776 if self.connection_string.startswith("sqlite"):
777 self.backend = SQLiteBackend(self.connection_string)
778 elif self.connection_string.startswith("postgres"):
779 self.backend = PostgreSQLBackend(self.connection_string)
780 else:
781 raise ValueError(
782 f"Unsupported database type in connection string: {self.connection_string}"
783 )
784 return self
786 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
787 """Exit context manager."""
788 if self.backend:
789 self.backend.close()
def create_table_if_not_exists(
    self, table_name: str, columns: dict[str, str], primary_keys: list[str] | None = None
) -> None:
    """Delegate table creation to the active backend.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    backend.create_table_if_not_exists(table_name, columns, primary_keys)
def get_existing_columns(self, table_name: str) -> set[str]:
    """Return the backend's set of existing column names for ``table_name``.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    return backend.get_existing_columns(table_name)
def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
    """Delegate adding a column to an existing table to the active backend.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    backend.add_column(table_name, column_name, column_type)
def upsert_row(
    self, table_name: str, conflict_columns: list[str], row_data: dict[str, Any]
) -> None:
    """Delegate a row upsert to the active backend.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    backend.upsert_row(table_name, conflict_columns, row_data)
def delete_stale_records_compound(
    self,
    table_name: str,
    id_columns: list[str],
    filter_columns: dict[str, str],
    current_ids: set[tuple],
) -> int:
    """Delegate stale-record deletion (compound filter key) to the active backend.

    Returns:
        Whatever count the backend reports.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    deleted = backend.delete_stale_records_compound(
        table_name, id_columns, filter_columns, current_ids
    )
    return deleted
def count_stale_records_compound(
    self,
    table_name: str,
    id_columns: list[str],
    filter_columns: dict[str, str],
    current_ids: set[tuple],
) -> int:
    """Delegate counting of would-be-deleted records to the active backend.

    Returns:
        Whatever count the backend reports.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    stale = backend.count_stale_records_compound(
        table_name, id_columns, filter_columns, current_ids
    )
    return stale
def get_existing_indexes(self, table_name: str) -> set[str]:
    """Return the backend's set of existing index names for ``table_name``.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    return backend.get_existing_indexes(table_name)
def create_index(
    self, table_name: str, index_name: str, columns: list[tuple[str, str]]
) -> None:
    """Delegate index creation on the given columns to the active backend.

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    backend.create_index(table_name, index_name, columns)
def table_exists(self, table_name: str) -> bool:
    """Ask the active backend whether a table exists.

    Args:
        table_name: Name of the table to check

    Returns:
        True if table exists, False otherwise

    Raises:
        RuntimeError: If no backend is set.
    """
    backend = self.backend
    if not backend:
        raise RuntimeError("Database connection not established")
    return backend.table_exists(table_name)
874 def _validate_id_columns(self, job: CrumpJob, csv_columns: set[str]) -> set[str]:
875 """Validate that required ID columns exist in CSV.
877 Args:
878 job: CrumpJob configuration
879 csv_columns: Set of column names from CSV
881 Returns:
882 Set of ID column names from CSV
884 Raises:
885 ValueError: If any ID column is missing from CSV
886 """
887 id_csv_columns = set()
888 for id_col in job.id_mapping:
889 # Skip validation for custom functions (no csv_column)
890 if id_col.csv_column is None:
891 # Custom function - validate input columns instead
892 if id_col.input_columns:
893 for input_col in id_col.input_columns:
894 if input_col not in csv_columns:
895 raise ValueError(
896 f"Input column '{input_col}' for custom function "
897 f"'{id_col.db_column}' not found in CSV"
898 )
899 continue
901 if id_col.csv_column not in csv_columns:
902 raise ValueError(f"ID column '{id_col.csv_column}' not found in CSV")
903 id_csv_columns.add(id_col.csv_column)
904 return id_csv_columns
906 def _determine_sync_columns(
907 self, job: CrumpJob, csv_columns: set[str], id_csv_columns: set[str]
908 ) -> list[Any]:
909 """Determine which columns to sync based on job configuration.
911 When failure_mode is set, missing CSV columns for configured mappings are
912 tolerated (the column is kept so rows can receive default/null values).
913 Custom function input columns that are missing always raise ValueError.
915 Args:
916 job: CrumpJob configuration
917 csv_columns: Set of column names from CSV
918 id_csv_columns: Set of ID column names
920 Returns:
921 List of ColumnMapping objects for columns to sync
923 Raises:
924 ValueError: If a custom function input column is missing from CSV
925 """
926 if job.columns:
927 # Specific columns defined
928 sync_columns = list(job.id_mapping) + job.columns
929 for col_mapping in job.columns:
930 # Skip validation for custom functions (no csv_column)
931 if col_mapping.csv_column is None:
932 # Custom function - validate input columns instead
933 if col_mapping.input_columns:
934 for input_col in col_mapping.input_columns:
935 if input_col not in csv_columns:
936 raise ValueError(
937 f"Input column '{input_col}' for custom function "
938 f"'{col_mapping.db_column}' not found in CSV"
939 )
940 continue
942 if col_mapping.csv_column not in csv_columns:
943 # Column is missing from CSV - log warning but keep it
944 # Row validation will handle this per-row based on failure_mode
945 logger.warning(
946 f"Column '{col_mapping.csv_column}' defined in config "
947 f"but not found in CSV file"
948 )
949 else:
950 # Sync all columns
951 sync_columns = list(job.id_mapping)
952 for csv_col in csv_columns:
953 if csv_col not in id_csv_columns:
954 sync_columns.append(ColumnMapping(csv_col, csv_col))
956 return sync_columns
958 def _build_column_definitions(self, sync_columns: list[Any], job: CrumpJob) -> dict[str, str]:
959 """Build column definitions with SQL types and nullable constraints.
961 Args:
962 sync_columns: List of ColumnMapping objects
963 job: CrumpJob configuration
965 Returns:
966 Dictionary mapping column names to SQL type definitions (including NULL/NOT NULL)
967 """
968 if not self.backend:
969 raise RuntimeError("Database connection not established")
970 columns_def = {}
971 for col_mapping in sync_columns:
972 sql_type = self.backend.map_data_type(col_mapping.data_type)
974 # Add nullable constraint if specified
975 if col_mapping.nullable is not None:
976 if col_mapping.nullable:
977 sql_type += " NULL"
978 else:
979 sql_type += " NOT NULL"
981 columns_def[col_mapping.db_column] = sql_type
983 # Add filename_to_column columns if configured
984 if job.filename_to_column:
985 for col_mapping in job.filename_to_column.columns.values():
986 sql_type = self.backend.map_data_type(col_mapping.data_type)
987 columns_def[col_mapping.db_column] = sql_type
989 return columns_def
991 def _setup_table_schema(
992 self, job: CrumpJob, columns_def: dict[str, str], primary_keys: list[str]
993 ) -> bool:
994 """Create table and add missing columns/indexes.
996 Args:
997 job: CrumpJob configuration
998 columns_def: Dictionary mapping column names to SQL types
999 primary_keys: List of primary key column names
1001 Returns:
1002 True if schema changes were made (table created, columns added, or indexes created)
1003 """
1004 schema_changed = False
1006 # Check if table exists before creating
1007 table_existed = self.table_exists(job.target_table)
1009 # Create table if it doesn't exist
1010 self.create_table_if_not_exists(job.target_table, columns_def, primary_keys)
1012 if not table_existed:
1013 schema_changed = True
1015 # Check for schema evolution: add missing columns from config
1016 existing_columns = self.get_existing_columns(job.target_table)
1017 for col_name, col_type in columns_def.items():
1018 if col_name.lower() not in existing_columns:
1019 self.add_column(job.target_table, col_name, col_type)
1020 schema_changed = True
1022 # Create indexes that don't already exist
1023 if job.indexes:
1024 existing_indexes = self.get_existing_indexes(job.target_table)
1025 for index in job.indexes:
1026 if index.name.lower() not in existing_indexes:
1027 index_columns = [(col.column, col.order) for col in index.columns]
1028 self.create_index(job.target_table, index.name, index_columns)
1029 schema_changed = True
1031 return schema_changed
1033 def _should_include_row(
1034 self, row_index: int, total_rows: int, sample_percentage: float | None
1035 ) -> bool:
1036 """Determine if a row should be included based on sampling percentage.
1038 Args:
1039 row_index: Zero-based index of the current row
1040 total_rows: Total number of rows in the dataset
1041 sample_percentage: Optional percentage of rows to sample (0-100)
1043 Returns:
1044 True if row should be included, False otherwise
1045 """
1046 # If no sampling or 100%, include all rows
1047 if sample_percentage is None or sample_percentage >= 100:
1048 return True
1050 # If 0%, exclude all rows (edge case)
1051 if sample_percentage <= 0:
1052 return False
1054 # Always include first row
1055 if row_index == 0:
1056 return True
1058 # Always include last row
1059 if row_index == total_rows - 1:
1060 return True
1062 # Sample other rows based on percentage
1063 # For 10%, interval = 10, so include rows 0, 10, 20, 30...
1064 # For 25%, interval = 4, so include rows 0, 4, 8, 12...
1065 interval = int(100 / sample_percentage)
1066 return row_index % interval == 0
1068 @staticmethod
1069 def _get_varchar_limit(data_type: str | None) -> int | None:
1070 """Extract the character limit from a varchar(N) type string.
1072 Args:
1073 data_type: Data type string, e.g. 'varchar(50)'
1075 Returns:
1076 The limit N, or None if not a varchar type
1077 """
1078 if data_type is None:
1079 return None
1080 import re as _re
1082 match = _re.match(r"varchar\((\d+)\)", data_type.lower().strip())
1083 if match:
1084 return int(match.group(1))
1085 return None
1087 # PostgreSQL integer range limits
1088 _INTEGER_MIN = -2147483648
1089 _INTEGER_MAX = 2147483647
1090 _BIGINT_MIN = -9223372036854775808
1091 _BIGINT_MAX = 9223372036854775807
1093 # Minimum datetime used as a permissive default for non-nullable datetime columns
1094 _MIN_DATETIME = datetime.datetime(1, 1, 1, 0, 0, 0)
1095 _MIN_DATE = datetime.date(1, 1, 1)
1097 @staticmethod
1098 def _get_integer_range(data_type: str | None) -> tuple[int, int] | None:
1099 """Get the valid integer range for the given data type.
1101 Args:
1102 data_type: Data type string, e.g. 'integer', 'bigint'
1104 Returns:
1105 (min, max) tuple, or None if not an integer type
1106 """
1107 if data_type is None:
1108 return None
1109 dt_lower = data_type.lower().strip()
1110 if dt_lower in ("integer", "int"):
1111 return (DatabaseConnection._INTEGER_MIN, DatabaseConnection._INTEGER_MAX)
1112 if dt_lower == "bigint":
1113 return (DatabaseConnection._BIGINT_MIN, DatabaseConnection._BIGINT_MAX)
1114 return None
1116 @staticmethod
1117 def _is_datetime_type(data_type: str | None) -> bool:
1118 """Check if the data type is a date or datetime type.
1120 Args:
1121 data_type: Data type string
1123 Returns:
1124 True if data_type is date, datetime, or timestamp
1125 """
1126 if data_type is None:
1127 return False
1128 return data_type.lower().strip() in ("date", "datetime", "timestamp")
1130 @staticmethod
1131 def _is_empty_datetime_value(value: Any) -> bool:
1132 """Check if a value represents an empty/null datetime.
1134 Args:
1135 value: The value to check
1137 Returns:
1138 True if the value is None, empty string, or whitespace-only string
1139 """
1140 if value is None:
1141 return True
1142 return isinstance(value, str) and value.strip() == ""
1144 @staticmethod
1145 def _get_default_value(data_type: str | None) -> Any:
1146 """Get the permissive default value for a non-nullable column.
1148 Args:
1149 data_type: The configured data type
1151 Returns:
1152 0 for integer/numeric types, min datetime for date/datetime types,
1153 empty string for text/string types
1154 """
1155 if data_type is None:
1156 return ""
1157 dt_lower = data_type.lower().strip()
1158 if dt_lower in ("integer", "int", "bigint"):
1159 return 0
1160 if dt_lower in ("float", "double"):
1161 return 0.0
1162 if dt_lower == "date":
1163 return DatabaseConnection._MIN_DATE
1164 if dt_lower in ("datetime", "timestamp"):
1165 return DatabaseConnection._MIN_DATETIME
1166 return ""
def _validate_and_fix_row(
    self,
    row_data: dict[str, Any],
    sync_columns: list[Any],
    job: CrumpJob,
    csv_row: dict[str, Any],
) -> dict[str, Any] | None:
    """Check a transformed row against its column rules, repairing or rejecting it.

    Per column, in this order:
    - Missing value: NULL when the column is not declared non-nullable;
      otherwise skip the row (STRICT) or substitute a default (non-STRICT)
    - String longer than varchar(N): skip the row (STRICT) or truncate
    - Integer outside the type's range: skip the row (STRICT); otherwise NULL
      when the column may be NULL, else skip the row
    - Empty/None date-like value: NULL when the column may be NULL; otherwise
      skip the row (STRICT) or use the minimum date/datetime

    Any failure_mode other than FailureMode.STRICT takes the permissive path.

    Args:
        row_data: The transformed row data (db_column → value); modified in place
        sync_columns: List of ColumnMapping objects
        job: CrumpJob configuration
        csv_row: The original CSV row, used to tell an absent CSV column from
            a genuinely empty value

    Returns:
        The validated/fixed row_data dict, or None if the row should be skipped
    """
    strict = job.failure_mode == FailureMode.STRICT

    for mapping in sync_columns:
        column = mapping.db_column
        declared_type = mapping.data_type
        required = mapping.nullable is False

        # A value counts as missing when the key is absent, the value is None,
        # or it is an empty string produced for a CSV column that is not in the row
        current = row_data.get(column)
        absent = column not in row_data or current is None
        if not absent and current == "":
            source_column = mapping.csv_column
            absent = source_column is not None and source_column not in csv_row

        if absent:
            if not required:
                # Nullable or unspecified → NULL
                row_data[column] = None
            elif strict:
                logger.warning(
                    f"STRICT mode: Skipping row - missing non-nullable field '{column}'"
                )
                return None
            else:
                default = self._get_default_value(declared_type)
                logger.warning(
                    f"PERMISSIVE mode: Using default value {default!r} "
                    f"for missing non-nullable field '{column}'"
                )
                row_data[column] = default

        # varchar(N) length check
        limit = self._get_varchar_limit(declared_type)
        if limit is not None and column in row_data and row_data[column] is not None:
            text = str(row_data[column])
            length = len(text)
            if length > limit:
                if strict:
                    logger.warning(
                        f"STRICT mode: Skipping row - value for '{column}' "
                        f"exceeds varchar({limit}) limit "
                        f"(length {length})"
                    )
                    return None
                logger.warning(
                    f"PERMISSIVE mode: Truncating value for '{column}' "
                    f"from {length} to {limit} characters"
                )
                row_data[column] = text[:limit]

        # Integer range check
        bounds = self._get_integer_range(declared_type)
        if bounds is not None and column in row_data and row_data[column] is not None:
            try:
                number = int(row_data[column])
            except (ValueError, TypeError):
                # Values that do not convert to int are not range-checked here
                number = None

            if number is not None and not (bounds[0] <= number <= bounds[1]):
                type_name = declared_type or "integer"
                if strict:
                    logger.warning(
                        f"STRICT mode: Skipping row - value {number} for "
                        f"'{column}' is out of {type_name} range "
                        f"[{bounds[0]}, {bounds[1]}]"
                    )
                    return None
                if required:
                    logger.warning(
                        f"PERMISSIVE mode: Skipping row - value {number} "
                        f"for non-nullable '{column}' is out of {type_name} "
                        f"range and cannot be set to NULL"
                    )
                    return None
                logger.warning(
                    f"PERMISSIVE mode: Setting '{column}' to NULL - "
                    f"value {number} is out of {type_name} range"
                )
                row_data[column] = None

        # Empty/None values in date-like columns
        if (
            self._is_datetime_type(declared_type)
            and column in row_data
            and self._is_empty_datetime_value(row_data[column])
        ):
            if not required:
                row_data[column] = None
            elif strict:
                logger.warning(
                    f"STRICT mode: Skipping row - empty datetime value "
                    f"for non-nullable field '{column}'"
                )
                return None
            else:
                default = self._get_default_value(declared_type)
                logger.warning(
                    f"PERMISSIVE mode: Using minimum datetime {default!r} "
                    f"for empty non-nullable field '{column}'"
                )
                row_data[column] = default

    return row_data
def _process_tabular_rows(
    self,
    reader: Any,
    job: CrumpJob,
    sync_columns: list[Any],
    primary_keys: list[str],
    filename_values: dict[str, str] | None = None,
) -> tuple[int, set[tuple]]:
    """Transform, validate and upsert each row of a tabular file.

    Args:
        reader: Tabular file reader (DictReader interface)
        job: CrumpJob configuration
        sync_columns: List of ColumnMapping objects
        primary_keys: List of primary key column names
        filename_values: Optional dict of values extracted from filename

    Returns:
        Tuple of (rows_synced, synced_ids) where synced_ids are tuples of ID values

    Raises:
        ValueError: In STRICT mode, when at least one row was rejected and no
            row was synced
    """
    synced_count = 0
    skipped_count = 0
    seen_ids: set[tuple] = set()

    if job.sample_percentage is not None and job.sample_percentage < 100:
        # Sampling needs the total row count, so the file is read into memory
        buffered = list(reader)
        row_total = len(buffered)
        candidates = (
            row
            for position, row in enumerate(buffered)
            if self._should_include_row(position, row_total, job.sample_percentage)
        )
    else:
        # No sampling: stream rows straight from the reader
        candidates = reader

    for row in candidates:
        transformed = apply_row_transformations(
            row, sync_columns, job.filename_to_column, filename_values
        )

        # None signals that validation rejected the row
        validated = self._validate_and_fix_row(transformed, sync_columns, job, row)
        if validated is None:
            skipped_count += 1
            continue

        self.upsert_row(job.target_table, primary_keys, validated)

        # IDs are recorded as tuples so compound keys are supported
        seen_ids.add(tuple(validated[id_col.db_column] for id_col in job.id_mapping))
        synced_count += 1

    if skipped_count > 0:
        logger.warning(f"Skipped {skipped_count} rows due to validation failures")

    # STRICT mode treats "every row rejected" as a hard failure
    if job.failure_mode == FailureMode.STRICT and skipped_count > 0 and synced_count == 0:
        raise ValueError(
            f"STRICT mode: All {skipped_count} row(s) were rejected due to "
            f"validation failures. No data was imported into '{job.target_table}'."
        )

    return synced_count, seen_ids
def _count_and_track_tabular_rows(
    self,
    file_path: Path,
    job: CrumpJob,
    sync_columns: list[Any],
    filename_values: dict[str, str] | None = None,
) -> tuple[int, set[tuple]]:
    """Count the rows of a tabular file and collect their IDs without touching the database.

    Rows are sampled and transformed as in a real sync, but they are neither
    validated nor written.

    Args:
        file_path: Path to tabular file (CSV or Parquet)
        job: CrumpJob configuration
        sync_columns: List of ColumnMapping objects
        filename_values: Optional dict of values extracted from filename

    Returns:
        Tuple of (row_count, synced_ids) where synced_ids are tuples of ID values
    """
    counted = 0
    seen_ids: set[tuple] = set()

    with create_reader(file_path, file_format=_detect_file_format(file_path)) as reader:
        if job.sample_percentage is not None and job.sample_percentage < 100:
            # Sampling needs the total row count, so the file is read into memory
            buffered = list(reader)
            row_total = len(buffered)
            candidates = (
                row
                for position, row in enumerate(buffered)
                if self._should_include_row(position, row_total, job.sample_percentage)
            )
        else:
            # No sampling: iterate the reader directly
            candidates = reader

        for row in candidates:
            transformed = apply_row_transformations(
                row, sync_columns, job.filename_to_column, filename_values
            )

            # IDs are recorded as tuples so compound keys are supported
            seen_ids.add(tuple(transformed[id_col.db_column] for id_col in job.id_mapping))
            counted += 1

    return counted, seen_ids
1459 def _prepare_sync(
1460 self, file_path: Path, job: CrumpJob
1461 ) -> tuple[set[str], list[Any], dict[str, str]]:
1462 """Prepare for sync by validating CSV and building schema definitions.
1464 Args:
1465 file_path: Path to tabular file (CSV or Parquet)
1466 job: CrumpJob configuration
1468 Returns:
1469 Tuple of (csv_columns, sync_columns, columns_def)
1471 Raises:
1472 FileNotFoundError: If CSV file doesn't exist
1473 ValueError: If CSV is invalid or columns don't match
1474 """
1475 if not file_path.exists():
1476 raise FileNotFoundError(f"File not found: {file_path}")
1478 file_format = _detect_file_format(file_path)
1480 with create_reader(file_path, file_format=file_format) as reader:
1481 if not reader.fieldnames:
1482 raise ValueError("File has no columns")
1483 csv_columns = set(reader.fieldnames)
1485 # Validate and determine columns to sync
1486 id_csv_columns = self._validate_id_columns(job, csv_columns)
1487 sync_columns = self._determine_sync_columns(job, csv_columns, id_csv_columns)
1489 # Build schema definitions
1490 columns_def = self._build_column_definitions(sync_columns, job)
1492 return csv_columns, sync_columns, columns_def
def sync_tabular_file_dry_run(
    self,
    file_path: Path,
    job: CrumpJob,
    filename_values: dict[str, str] | None = None,
) -> DryRunSummary:
    """Report what syncing a tabular file would change, without writing anything.

    Args:
        file_path: Path to tabular file (CSV or Parquet)
        job: CrumpJob configuration
        filename_values: Optional dict of values extracted from filename

    Returns:
        DryRunSummary with details of what would be changed

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the file is invalid or columns don't match
    """
    summary = DryRunSummary()
    table = job.target_table
    summary.table_name = table

    # Validates the file and builds the schema definitions
    _, sync_columns, columns_def = self._prepare_sync(file_path, job)

    summary.table_exists = self.table_exists(table)

    if summary.table_exists:
        # Columns that would have to be added
        present_columns = self.get_existing_columns(table)
        summary.new_columns.extend(
            (name, sql_type)
            for name, sql_type in columns_def.items()
            if name.lower() not in present_columns
        )

        # Indexes that would have to be created
        if job.indexes:
            present_indexes = self.get_existing_indexes(table)
            summary.new_indexes.extend(
                index.name for index in job.indexes if index.name.lower() not in present_indexes
            )

    # NOTE: every row in the file is counted, even one identical to existing
    # data. Comparing against stored rows would be expensive for large
    # datasets, so the figure is an upper bound on the rows that could change.
    summary.rows_to_sync, synced_ids = self._count_and_track_tabular_rows(
        file_path, job, sync_columns, filename_values
    )

    # Stale records that would be deleted
    if job.filename_to_column and filename_values and summary.table_exists:
        if job.filename_to_column.get_delete_key_columns():
            # Compound delete key built from the values parsed out of the filename
            delete_key_values = {
                mapping.db_column: filename_values[key]
                for key, mapping in job.filename_to_column.columns.items()
                if mapping.use_to_delete_old_rows and key in filename_values
            }

            id_columns = [id_col.db_column for id_col in job.id_mapping]
            summary.rows_to_delete = self.count_stale_records_compound(
                table,
                id_columns,
                delete_key_values,
                synced_ids,
            )

    return summary
def sync_tabular_file(
    self,
    file_path: Path,
    job: CrumpJob,
    filename_values: dict[str, str] | None = None,
    enable_history: bool = False,
) -> int:
    """Sync a tabular file into the job's target table.

    Prepares the schema, upserts the file's rows, removes stale records when
    the job has filename-derived delete keys, and optionally records a
    history entry for the run whether it succeeded or failed.

    Args:
        file_path: Path to tabular file (CSV or Parquet)
        job: CrumpJob configuration
        filename_values: Optional dict of values extracted from filename
        enable_history: Whether to record sync history

    Returns:
        Number of rows synced

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the file is invalid or columns don't match
    """
    from crump.history import get_utc_now, record_sync_history

    # The start time is only taken when history is enabled
    start_time = get_utc_now() if enable_history else None
    rows_deleted = 0
    schema_changed = False
    failure_text: str | None = None
    success = False

    try:
        # Validates the file and builds the schema definitions
        _, sync_columns, columns_def = self._prepare_sync(file_path, job)

        # Create or evolve the target table
        primary_keys = [id_col.db_column for id_col in job.id_mapping]
        logger.debug(f"Primary keys for table {job.target_table}: {primary_keys}")
        schema_changed = self._setup_table_schema(job, columns_def, primary_keys)

        # Upsert the rows
        with create_reader(file_path, file_format=_detect_file_format(file_path)) as reader:
            rows_synced, synced_ids = self._process_tabular_rows(
                reader, job, sync_columns, primary_keys, filename_values
            )

        # Remove stale records
        if job.filename_to_column and filename_values:
            if job.filename_to_column.get_delete_key_columns():
                # Compound delete key built from the values parsed out of the filename
                delete_key_values = {
                    mapping.db_column: filename_values[key]
                    for key, mapping in job.filename_to_column.columns.items()
                    if mapping.use_to_delete_old_rows and key in filename_values
                }

                id_columns = [id_col.db_column for id_col in job.id_mapping]
                rows_deleted = self.delete_stale_records_compound(
                    job.target_table,
                    id_columns,
                    delete_key_values,
                    synced_ids,
                )

        success = True
        return rows_synced

    except Exception as exc:
        # Keep the message for the history record, then propagate unchanged
        failure_text = str(exc)
        raise

    finally:
        if enable_history and self.backend and start_time:
            end_time = get_utc_now()
            # rows_synced may be unbound after a failure, so read it only on success
            upserted = rows_synced if success else 0
            try:
                record_sync_history(
                    backend=self.backend,
                    file_path=file_path,
                    table_name=job.target_table,
                    rows_upserted=upserted,
                    rows_deleted=rows_deleted,
                    schema_changed=schema_changed,
                    start_time=start_time,
                    end_time=end_time,
                    success=success,
                    error=failure_text,
                )
            except Exception as hist_error:
                # A history-recording failure must not fail the sync itself
                logger.warning(f"Failed to record sync history: {hist_error}")
def sync_file_to_db(
    file_path: Path,
    job: CrumpJob,
    db_connection_string: str,
    filename_values: dict[str, str] | None = None,
    enable_history: bool = False,
) -> int:
    """Open a database connection and sync a tabular file (CSV or Parquet) through it.

    Args:
        file_path: Path to the tabular file (CSV or Parquet)
        job: CrumpJob configuration
        db_connection_string: Database connection string (PostgreSQL or SQLite)
        filename_values: Optional dict of values extracted from filename
        enable_history: Whether to record sync history

    Returns:
        Number of rows synced
    """
    with DatabaseConnection(db_connection_string) as connection:
        synced = connection.sync_tabular_file(file_path, job, filename_values, enable_history)
        return synced
def sync_file_to_db_dry_run(
    file_path: Path,
    job: CrumpJob,
    db_connection_string: str,
    filename_values: dict[str, str] | None = None,
) -> DryRunSummary:
    """Open a database connection and run a dry-run sync of a tabular file through it.

    Args:
        file_path: Path to the tabular file (CSV or Parquet)
        job: CrumpJob configuration
        db_connection_string: Database connection string
        filename_values: Optional dict of values extracted from filename

    Returns:
        DryRunSummary with details of what would be changed
    """
    with DatabaseConnection(db_connection_string) as connection:
        summary = connection.sync_tabular_file_dry_run(file_path, job, filename_values)
        return summary
1707# Backward compatibility aliases