73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
"""asyncpg connection pool for Spore Commons."""
|
|
|
|
import asyncpg
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

# Shared connection pool for the process. init_pool() assigns it;
# close_pool() resets it to None. Read it through get_pool().
_pool: asyncpg.Pool | None = None
|
|
|
|
|
|
async def init_pool(
    database_url: str,
    *,
    min_size: int = 2,
    max_size: int = 10,
    command_timeout: float = 30,
) -> asyncpg.Pool:
    """Initialize the connection pool and run migrations.

    The pool becomes visible to ``get_pool()`` only after every pending
    migration has been applied. If migrations fail, the freshly created pool
    is terminated (so its connections are not leaked) and the exception
    propagates to the caller.

    Args:
        database_url: DSN passed to ``asyncpg.create_pool``.
        min_size: Minimum number of connections kept open by the pool.
        max_size: Maximum number of connections the pool may open.
        command_timeout: Default per-command timeout, in seconds.

    Returns:
        The initialized pool. The same object is stored for ``get_pool()``.
    """
    global _pool
    pool = await asyncpg.create_pool(
        database_url,
        min_size=min_size,
        max_size=max_size,
        command_timeout=command_timeout,
    )
    try:
        await _run_migrations(pool)
    except BaseException:
        # Catch BaseException too, so cancellation during startup is covered.
        # terminate() is synchronous, so nothing is awaited on this path; it
        # drops the connections the new pool already opened instead of
        # leaving them dangling.
        pool.terminate()
        raise
    # Publish the pool only once the schema is up to date.
    _pool = pool
    log.info("Database pool initialized")
    return pool
|
|
|
|
|
|
async def _run_migrations(pool: asyncpg.Pool) -> None:
    """Run SQL migration files in order.

    Looks for ``*.sql`` files in the ``migrations`` directory next to this
    module. If that directory is absent, the function returns without
    touching the database.

    Files are applied in lexicographic filename order, so numeric prefixes
    should be zero-padded (``001_init.sql``, ``002_...``).

    Applied filenames are recorded in the ``_migrations`` table and skipped
    on later runs. Each file is executed in its own transaction together with
    its ``_migrations`` insert, so a file that fails is rolled back, is not
    recorded, and will be attempted again on the next run.
    """
    migrations_dir = Path(__file__).parent / "migrations"
    if not migrations_dir.is_dir():
        return

    # NOTE(review): there is no cross-process lock here. If several instances
    # start concurrently, they may all see a file as pending. The PRIMARY KEY
    # on filename makes every instance but one fail on the insert, and that
    # failure rolls back its transaction. Consider pg_advisory_lock if this
    # matters in your deployment.
    async with pool.acquire() as conn:
        # Create migrations tracking table
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS _migrations (
                filename TEXT PRIMARY KEY,
                applied_at TIMESTAMPTZ DEFAULT now()
            )
        """)

        applied = {
            row["filename"]
            for row in await conn.fetch("SELECT filename FROM _migrations")
        }

        for sql_file in sorted(migrations_dir.glob("*.sql")):
            if sql_file.name in applied:
                continue
            log.info("Applying migration: %s", sql_file.name)
            # Decode explicitly, so the result does not depend on the host
            # locale's default encoding.
            sql = sql_file.read_text(encoding="utf-8")
            async with conn.transaction():
                await conn.execute(sql)
                await conn.execute(
                    "INSERT INTO _migrations (filename) VALUES ($1)",
                    sql_file.name,
                )
|
|
|
|
|
|
def get_pool() -> asyncpg.Pool:
    """Return the module-wide connection pool.

    Raises:
        RuntimeError: If no pool is currently set, either because init_pool()
            has not run yet or because close_pool() has cleared it.
    """
    pool = _pool
    if pool is not None:
        return pool
    raise RuntimeError("Database pool not initialized")
|
|
|
|
|
|
async def close_pool() -> None:
    """Close the connection pool, if one is set.

    The module-level reference is cleared *before* awaiting ``close()``.
    While the shutdown is suspended, ``get_pool()`` therefore fails fast
    instead of handing out a closing pool, and a concurrent ``close_pool()``
    call becomes a no-op instead of closing the same pool twice. Calling
    this function when no pool is set does nothing.
    """
    global _pool
    # Detach first: the await below yields to other tasks.
    pool, _pool = _pool, None
    if pool is not None:
        await pool.close()
|