diff --git a/backup-service/backup.py b/backup-service/backup.py index 075a4ca..0718be0 100644 --- a/backup-service/backup.py +++ b/backup-service/backup.py @@ -79,6 +79,8 @@ def create_backup() -> tuple[str, bytes]: config.DB_NAME, "--no-owner", "--no-acl", + "--clean", # Add DROP commands before CREATE + "--if-exists", # Use IF EXISTS with DROP commands "-F", "p", # plain SQL format ] diff --git a/backup-service/restore.py b/backup-service/restore.py index 23f48fb..f945b81 100644 --- a/backup-service/restore.py +++ b/backup-service/restore.py @@ -3,8 +3,9 @@ Restore PostgreSQL database from S3 backup. Usage: - python restore.py - List available backups - python restore.py - Restore from specific backup + python restore.py - List available backups + python restore.py - Restore from backup (cleans DB first) + python restore.py --no-clean - Restore without cleaning DB first """ import gzip import os @@ -62,7 +63,48 @@ def list_backups(s3_client) -> list[tuple[str, float, str]]: return [] -def restore_backup(s3_client, filename: str) -> None: +def clean_database() -> None: + """Drop and recreate public schema to clean the database.""" + print("Cleaning database (dropping and recreating public schema)...") + + env = os.environ.copy() + env["PGPASSWORD"] = config.DB_PASSWORD + + # Drop and recreate public schema + clean_sql = b""" +DROP SCHEMA public CASCADE; +CREATE SCHEMA public; +GRANT ALL ON SCHEMA public TO public; +""" + + cmd = [ + "psql", + "-h", + config.DB_HOST, + "-p", + config.DB_PORT, + "-U", + config.DB_USER, + "-d", + config.DB_NAME, + ] + + result = subprocess.run( + cmd, + env=env, + input=clean_sql, + capture_output=True, + ) + + if result.returncode != 0: + stderr = result.stderr.decode() + if "ERROR" in stderr: + raise Exception(f"Database cleanup failed: {stderr}") + + print("Database cleaned successfully!") + + +def restore_backup(s3_client, filename: str, clean_first: bool = True) -> None: """Download and restore backup.""" key = f"{config.S3_BACKUP_PREFIX}{filename}" @@ -79,6 +121,10 @@ def restore_backup(s3_client, filename: str) -> None: print("Decompressing...") sql_data = gzip.decompress(compressed_data) + # Clean database before restore if requested + if clean_first: + clean_database() + print(f"Restoring to database {config.DB_NAME}...") # Build psql command @@ -124,20 +170,32 @@ def main() -> int: s3_client = create_s3_client() - if len(sys.argv) < 2: + # Parse arguments + args = sys.argv[1:] + clean_first = True + + if "--no-clean" in args: + clean_first = False + args.remove("--no-clean") + + if len(args) < 1: # List available backups backups = list_backups(s3_client) if backups: print(f"\nTo restore, run: python restore.py ") + print("Add --no-clean to skip database cleanup before restore") else: print("No backups found.") return 0 - filename = sys.argv[1] + filename = args[0] # Confirm restore print(f"WARNING: This will restore database from {filename}") - print("This may overwrite existing data!") + if clean_first: + print("Database will be CLEANED (all existing data will be DELETED)!") + else: + print("Database will NOT be cleaned (may cause conflicts with existing data)") print() confirm = input("Type 'yes' to continue: ") @@ -147,7 +205,7 @@ def main() -> int: return 0 try: - restore_backup(s3_client, filename) + restore_backup(s3_client, filename, clean_first=clean_first) return 0 except Exception as e: print(f"Restore failed: {e}")