|
2 | 2 | from sqlalchemy.orm import Session
|
3 | 3 | from sqlalchemy import create_engine, text
|
4 | 4 | from sqlalchemy.orm import sessionmaker
|
| 5 | +from contextlib import closing |
| 6 | +import logging |
5 | 7 | import typer
|
6 | 8 | from fastapi import FastAPI
|
7 | 9 |
|
8 | 10 | from app.models import Database
|
9 | 11 |
|
| 12 | +# Initialize FastAPI and Typer CLI |
10 | 13 | app = FastAPI()
|
11 | 14 | cli = typer.Typer()
|
12 | 15 |
|
| 16 | +# Configure logging |
| 17 | +logging.basicConfig(level=logging.INFO) |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | + |
| 21 | +def create_session(db_url: str) -> Session: |
| 22 | + """Creates a new SQLAlchemy session for the provided database URL.""" |
| 23 | + engine = create_engine(db_url, pool_pre_ping=True, |
| 24 | + pool_size=10, max_overflow=20) |
| 25 | + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
| 26 | + return SessionLocal() |
| 27 | + |
| 28 | + |
| 29 | +def get_project_by_id(session: Session, project_id: int): |
| 30 | + """Fetch a project by its ID.""" |
| 31 | + return session.execute( |
| 32 | + text('SELECT * FROM project WHERE id = :id'), |
| 33 | + {'id': project_id} |
| 34 | + ).fetchone() |
| 35 | + |
| 36 | + |
| 37 | +def get_user_by_id(session: Session, user_id: str): |
| 38 | + """Fetch a user by their ID.""" |
| 39 | + return session.execute( |
| 40 | + text('SELECT * FROM "user" WHERE id = :id'), |
| 41 | + {'id': user_id} |
| 42 | + ).fetchone() |
| 43 | + |
13 | 44 |
|
14 | 45 | @cli.command()
|
15 | 46 | def update_users(db_url: str = typer.Option(None, help="Backend Database URL")):
|
16 |
| - print(f"Updating users") |
| 47 | + """Update users in the database by assigning owner_id to databases.""" |
| 48 | + logger.info("Starting user update process...") |
| 49 | + if not db_url: |
| 50 | + logger.error("No database URL provided. Exiting.") |
| 51 | + return |
| 52 | + |
17 | 53 | db: Session = next(get_db())
|
| 54 | + updated_count = 0 |
18 | 55 | try:
|
| 56 | + # Fetch databases without owners |
19 | 57 | databases = db.query(Database).filter(
|
20 |
| - Database.owner_id == None).all() |
21 |
| - print(f"Found {len(databases)} databases without owner") |
| 58 | + Database.owner_id.is_(None)).all() |
| 59 | + logger.info(f"Found {len(databases)} databases without owner.") |
22 | 60 |
|
23 |
| - if not db_url: |
24 |
| - print("No database URL provided") |
| 61 | + if not databases: |
| 62 | + logger.info("No databases to update. Exiting.") |
25 | 63 | return
|
26 |
| - engine = create_engine(db_url, pool_pre_ping=True, |
27 |
| - pool_size=10, max_overflow=20) |
28 |
| - NewSessionLocal = sessionmaker( |
29 |
| - autocommit=False, autoflush=False, bind=engine) |
30 |
| - new_db = NewSessionLocal() |
31 |
| - try: |
| 64 | + |
| 65 | + # Create a new session for the provided db_url |
| 66 | + with closing(create_session(db_url)) as new_db: |
32 | 67 | for database in databases:
|
33 |
| - project = new_db.execute(text('SELECT * FROM project WHERE id = :id'), { |
34 |
| - 'id': database.project_id}).fetchone() |
| 68 | + project = get_project_by_id(new_db, database.project_id) |
35 | 69 | if not project:
|
| 70 | + logger.warning(f"""No project found for database ID { |
| 71 | + database.id}. Skipping.""") |
36 | 72 | continue
|
37 | 73 |
|
38 |
| - user = new_db.execute(text('SELECT * FROM "user" WHERE id = :id'), { |
39 |
| - 'id': str(project.owner_id)}).fetchone() |
| 74 | + user = get_user_by_id(new_db, str(project.owner_id)) |
40 | 75 | if not user:
|
| 76 | + logger.warning(f"""No user found for project ID { |
| 77 | + project.id}. Skipping.""") |
41 | 78 | continue
|
42 | 79 |
|
| 80 | + # Update database with owner details |
43 | 81 | database.owner_id = user.id
|
44 | 82 | database.email = user.email
|
| 83 | + updated_count += 1 |
| 84 | + logger.info(f"""Updated database ID { |
| 85 | + database.id} with owner ID {user.id}.""") |
45 | 86 | db.commit()
|
46 | 87 |
|
47 |
| - finally: |
48 |
| - new_db.close() |
49 |
| - |
| 88 | + except Exception as e: |
| 89 | + logger.error(f"An error occurred: {e}", exc_info=True) |
50 | 90 | finally:
|
51 | 91 | db.close()
|
| 92 | + logger.info(f"""Database session closed. Total successful updates: { |
| 93 | + updated_count}""") |
52 | 94 |
|
53 | 95 |
|
54 | 96 | if __name__ == "__main__":
|
|
0 commit comments