# Source: clip-forge/backend/app/api/routes/jobs.py (168 lines, 4.8 KiB, Python)

import asyncio
import json
from uuid import UUID
from arq import create_pool
from arq.connections import RedisSettings
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse
from app.config import settings
from app.database import get_db
from app.models import Job, Clip
from app.schemas import JobCreate, JobResponse, ClipResponse
# Router on which every job endpoint in this module is registered.
router = APIRouter()
def _redis_settings() -> RedisSettings:
    """Build arq connection settings from ``settings.redis_url``.

    The URL is split into host, port, database index, password and TLS flag.
    Missing parts fall back to host ``"redis"``, port ``6379`` and database
    ``0``.

    Returns:
        A ``RedisSettings`` instance suitable for ``arq.create_pool``.

    Raises:
        ValueError: If the URL path is not an integer database index or the
            port is not a valid number.
    """
    from urllib.parse import unquote, urlparse

    parsed = urlparse(settings.redis_url)
    # Credentials may be percent-encoded inside a URL; decode them before
    # handing them to arq. Without this the password in the URL was dropped.
    password = unquote(parsed.password) if parsed.password else None
    return RedisSettings(
        host=parsed.hostname or "redis",
        port=parsed.port or 6379,
        database=int(parsed.path.lstrip("/") or "0"),
        password=password,
        # The ``rediss`` scheme requests a TLS connection.
        ssl=parsed.scheme == "rediss",
    )
@router.post("/jobs", response_model=JobResponse, status_code=201)
async def create_job(job_in: JobCreate, db: AsyncSession = Depends(get_db)):
    """Create a job from a request payload and queue it for processing.

    Args:
        job_in: Request body carrying ``source_type`` and ``source_url``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The newly persisted ``Job`` (status ``"pending"``).

    Raises:
        HTTPException: 400 when ``source_type`` is ``"youtube"`` and no
            ``source_url`` was given.
    """
    if job_in.source_type == "youtube" and not job_in.source_url:
        raise HTTPException(400, "source_url required for youtube source")

    job = Job(
        source_type=job_in.source_type,
        source_url=job_in.source_url,
        status="pending",
    )
    db.add(job)
    await db.commit()
    # Reload the row so database-generated fields (e.g. the id) are populated.
    await db.refresh(job)

    # Enqueue processing. The pool is closed in ``finally`` so a failing
    # enqueue does not leak the Redis connection.
    pool = await create_pool(_redis_settings())
    try:
        await pool.enqueue_job("process_job", str(job.id))
    finally:
        await pool.close()
    return job
@router.post("/jobs/upload", response_model=JobResponse, status_code=201)
async def create_job_upload(
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Store an uploaded media file, create a job for it and queue it.

    The upload is streamed to ``settings.media_dir`` in 1 MiB chunks, then a
    ``Job`` with ``source_type="upload"`` is persisted and a ``process_job``
    task is enqueued.

    Args:
        file: The uploaded file.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The newly persisted ``Job`` (status ``"pending"``).
    """
    import os
    from uuid import uuid4

    import aiofiles

    os.makedirs(settings.media_dir, exist_ok=True)

    # ``filename`` is client-controlled and may be absent altogether.
    original_name = file.filename or "unnamed"
    # Neutralise path separators (both styles) and parent-directory markers.
    safe_name = (
        original_name.replace("/", "_").replace("\\", "_").replace("..", "_")
    )
    # A random component keeps two uploads that share a filename from
    # overwriting each other's media.
    dest = os.path.join(
        settings.media_dir, f"upload_{uuid4().hex}_{safe_name}"
    )

    async with aiofiles.open(dest, "wb") as f:
        while chunk := await file.read(1024 * 1024):
            await f.write(chunk)

    job = Job(
        source_type="upload",
        source_filename=safe_name,
        media_path=dest,
        status="pending",
    )
    db.add(job)
    await db.commit()
    await db.refresh(job)

    # Close the pool in ``finally`` so a failing enqueue does not leak the
    # Redis connection.
    pool = await create_pool(_redis_settings())
    try:
        await pool.enqueue_job("process_job", str(job.id))
    finally:
        await pool.close()
    return job
@router.get("/jobs", response_model=list[JobResponse])
async def list_jobs(
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db),
):
    """Return one page of jobs, newest first.

    Args:
        limit: Maximum number of jobs to return.
        offset: Number of jobs to skip from the start of the ordering.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A list of ``Job`` rows ordered by ``created_at`` descending.

    Raises:
        HTTPException: 400 when ``limit`` or ``offset`` is negative.
    """
    # Reject negative paging values here instead of letting them reach the
    # database, where they would surface as a server error.
    if limit < 0 or offset < 0:
        raise HTTPException(400, "limit and offset must be non-negative")

    result = await db.execute(
        select(Job).order_by(Job.created_at.desc()).offset(offset).limit(limit)
    )
    return result.scalars().all()
@router.get("/jobs/{job_id}", response_model=JobResponse)
async def get_job(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Look up a single job by its primary key.

    Args:
        job_id: Identifier of the job to fetch.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The matching ``Job``.

    Raises:
        HTTPException: 404 when no job has that id.
    """
    record = await db.get(Job, job_id)
    if record:
        return record
    raise HTTPException(404, "Job not found")
@router.get("/jobs/{job_id}/clips", response_model=list[ClipResponse])
async def get_job_clips(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """List the clips of one job, highest ``virality_score`` first.

    Args:
        job_id: Identifier of the job whose clips are requested.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The job's ``Clip`` rows, each with ``duration`` set to
        ``end_time - start_time``.

    Raises:
        HTTPException: 404 when no job has that id.
    """
    parent = await db.get(Job, job_id)
    if not parent:
        raise HTTPException(404, "Job not found")

    ranked_query = (
        select(Clip)
        .where(Clip.job_id == job_id)
        .order_by(Clip.virality_score.desc())
    )
    rows = await db.execute(ranked_query)
    found = rows.scalars().all()

    # ``duration`` is a generated column, so derive it here from the
    # clip's own start and end times.
    for item in found:
        item.duration = item.end_time - item.start_time
    return found
@router.get("/jobs/{job_id}/progress")
async def job_progress_sse(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Stream a job's progress to the client as server-sent events.

    The stream first emits the job's current state from the database, then
    relays every message published on the Redis channel
    ``job:{job_id}:progress`` until one reports status ``"complete"`` or
    ``"failed"``.

    Args:
        job_id: Identifier of the job to follow.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        An ``EventSourceResponse`` yielding ``progress`` events whose data is
        a JSON object.

    Raises:
        HTTPException: 404 when no job has that id.
    """
    job = await db.get(Job, job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    terminal_states = ("complete", "failed")

    async def event_stream():
        import redis.asyncio as aioredis

        r = aioredis.from_url(settings.redis_url)
        pubsub = r.pubsub()
        # Everything from here on runs under ``finally`` so the Redis
        # resources are released even if subscribing, the refresh, or the
        # first yield is interrupted (e.g. by a client disconnect).
        try:
            # Subscribing before reading the row means updates published
            # after the read are still delivered on the channel.
            await pubsub.subscribe(f"job:{job_id}:progress")

            # Send current state immediately.
            # NOTE(review): this reuses the request-scoped session while the
            # response is streaming; confirm get_db keeps it open that long.
            await db.refresh(job)
            yield {
                "event": "progress",
                "data": json.dumps({
                    "status": job.status,
                    "progress": job.progress,
                    "stage_message": job.stage_message,
                }),
            }
            if job.status in terminal_states:
                return

            while True:
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if not (msg and msg["type"] == "message"):
                    # Nothing to relay: pause briefly before polling again.
                    await asyncio.sleep(0.5)
                    continue
                # Relay immediately; no pause between consecutive messages.
                data = json.loads(msg["data"])
                yield {"event": "progress", "data": json.dumps(data)}
                if data.get("status") in terminal_states:
                    break
        finally:
            # Close the client even when unsubscribing itself fails.
            try:
                await pubsub.unsubscribe()
            finally:
                await r.close()

    return EventSourceResponse(event_stream())