import asyncio
import json
from uuid import UUID, uuid4

from arq import create_pool
from arq.connections import RedisSettings
from fastapi import (
    APIRouter,
    Depends,
    File,
    Form,
    HTTPException,
    Query,
    UploadFile,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from app.config import settings
from app.database import get_db
from app.models import Clip, Job
from app.schemas import ClipResponse, JobCreate, JobResponse

router = APIRouter()

# Job statuses after which no further progress events are published.
_TERMINAL_STATUSES = ("complete", "failed")


def _redis_settings() -> RedisSettings:
    """Build arq ``RedisSettings`` from ``settings.redis_url``.

    Host, port, database index, password and TLS (``rediss://`` scheme) are
    all taken from the URL, so the arq pool connects the same way as the
    ``redis.asyncio`` client in ``job_progress_sse``, which consumes the full
    URL directly.
    """
    from urllib.parse import unquote, urlparse

    parsed = urlparse(settings.redis_url)
    # urlparse leaves userinfo percent-encoded; decode it like redis-py does.
    password = unquote(parsed.password) if parsed.password is not None else None
    return RedisSettings(
        host=parsed.hostname or "redis",
        port=parsed.port or 6379,
        database=int(parsed.path.lstrip("/") or "0"),
        password=password,
        ssl=parsed.scheme == "rediss",
    )


async def _enqueue_processing(job_id: UUID) -> None:
    """Enqueue the ``process_job`` arq task for *job_id*.

    A short-lived pool is opened per call; ``finally`` guarantees it is closed
    even when ``enqueue_job`` raises.
    """
    pool = await create_pool(_redis_settings())
    try:
        await pool.enqueue_job("process_job", str(job_id))
    finally:
        await pool.close()


@router.post("/jobs", response_model=JobResponse, status_code=201)
async def create_job(job_in: JobCreate, db: AsyncSession = Depends(get_db)):
    """Create a job from a JSON payload and enqueue it for processing.

    Raises:
        HTTPException: 400 when ``source_type`` is ``"youtube"`` but no
            ``source_url`` was supplied.
    """
    if job_in.source_type == "youtube" and not job_in.source_url:
        raise HTTPException(400, "source_url required for youtube source")

    job = Job(
        source_type=job_in.source_type,
        source_url=job_in.source_url,
        status="pending",
    )
    db.add(job)
    await db.commit()
    # Reload the committed row so its attributes (``id`` below, and the
    # response body) can be read after the commit.
    await db.refresh(job)

    await _enqueue_processing(job.id)
    return job


@router.post("/jobs/upload", response_model=JobResponse, status_code=201)
async def create_job_upload(
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Store an uploaded media file, create an ``upload`` job, and enqueue it.

    The file is streamed into ``settings.media_dir`` in 1 MiB chunks. The
    stored name carries a random prefix so two uploads with the same client
    filename never overwrite each other's media.
    """
    import contextlib
    import os

    import aiofiles

    os.makedirs(settings.media_dir, exist_ok=True)

    # UploadFile.filename may be None. Neutralise path separators and ".." so
    # the client-supplied name cannot steer the write outside media_dir.
    client_name = file.filename or "upload"
    safe_name = (
        client_name.replace("/", "_").replace("\\", "_").replace("..", "_")
    )
    dest = os.path.join(
        settings.media_dir, f"upload_{uuid4().hex}_{safe_name}"
    )

    try:
        async with aiofiles.open(dest, "wb") as f:
            while chunk := await file.read(1024 * 1024):
                await f.write(chunk)
    except BaseException:
        # Includes cancellation: don't leave a truncated file behind.
        with contextlib.suppress(OSError):
            os.remove(dest)
        raise

    job = Job(
        source_type="upload",
        source_filename=safe_name,
        media_path=dest,
        status="pending",
    )
    db.add(job)
    await db.commit()
    await db.refresh(job)

    await _enqueue_processing(job.id)
    return job


@router.get("/jobs", response_model=list[JobResponse])
async def list_jobs(
    limit: int = Query(default=20, ge=0),
    offset: int = Query(default=0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Return a page of jobs, newest first.

    Negative ``limit``/``offset`` values are rejected with a validation error
    instead of being passed through to the database.
    """
    result = await db.execute(
        select(Job).order_by(Job.created_at.desc()).offset(offset).limit(limit)
    )
    return result.scalars().all()


@router.get("/jobs/{job_id}", response_model=JobResponse)
async def get_job(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Return a single job by id.

    Raises:
        HTTPException: 404 when no job has this id.
    """
    job = await db.get(Job, job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    return job


@router.get("/jobs/{job_id}/clips", response_model=list[ClipResponse])
async def get_job_clips(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Return a job's clips ordered by ``virality_score``, highest first.

    Raises:
        HTTPException: 404 when no job has this id.
    """
    job = await db.get(Job, job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    result = await db.execute(
        select(Clip)
        .where(Clip.job_id == job_id)
        .order_by(Clip.virality_score.desc())
    )
    clips = result.scalars().all()

    # Compute duration manually since it's a generated column.
    # NOTE(review): assigning to a mapped attribute marks the instance dirty;
    # if get_db commits on teardown this may try to write the generated
    # column — confirm how Clip.duration is mapped.
    for clip in clips:
        clip.duration = clip.end_time - clip.start_time
    return clips


@router.get("/jobs/{job_id}/progress")
async def job_progress_sse(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Stream job progress as Server-Sent Events.

    The first ``progress`` event is the job's current database state; later
    events are relayed from the Redis channel ``job:{job_id}:progress`` until
    a message reports a terminal status. If the job is already terminal, only
    the snapshot is sent and Redis is never contacted.

    All database access happens here, before the response is returned. The
    streaming generator therefore never touches the request-scoped session
    (which, depending on the FastAPI version, may already have been torn down
    by ``get_db`` while the body streams) and does not hold a DB connection
    for the lifetime of the stream.

    Raises:
        HTTPException: 404 when no job has this id.
    """
    import redis.asyncio as aioredis

    job = await db.get(Job, job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    r = None
    pubsub = None
    if job.status not in _TERMINAL_STATUSES:
        r = aioredis.from_url(settings.redis_url)
        pubsub = r.pubsub()
        try:
            await pubsub.subscribe(f"job:{job_id}:progress")
            # Re-read only after subscribing, so an update published between
            # the initial load and the subscription is not lost.
            await db.refresh(job)
        except BaseException:
            await r.close()
            raise

    snapshot = {
        "status": job.status,
        "progress": job.progress,
        "stage_message": job.stage_message,
    }

    async def event_stream():
        try:
            yield {"event": "progress", "data": json.dumps(snapshot)}
            if pubsub is None or snapshot["status"] in _TERMINAL_STATUSES:
                return

            while True:
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if msg is None or msg["type"] != "message":
                    # Idle: back off briefly before polling again. Delivered
                    # messages are handled without this delay so bursts of
                    # progress updates are relayed promptly.
                    await asyncio.sleep(0.5)
                    continue

                data = json.loads(msg["data"])
                yield {"event": "progress", "data": json.dumps(data)}
                if data.get("status") in _TERMINAL_STATUSES:
                    break
        finally:
            # Runs on normal completion and on client disconnect (generator
            # close); the client is closed even if unsubscribe fails.
            if pubsub is not None:
                try:
                    await pubsub.unsubscribe()
                finally:
                    await r.close()

    return EventSourceResponse(event_stream())