import asyncio
import json
from uuid import UUID

from arq import create_pool
from arq.connections import RedisSettings
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from app.config import settings
from app.database import get_db
from app.models import Job, Clip
from app.schemas import JobCreate, JobResponse, ClipResponse

router = APIRouter()
|
|
|
|
|
|
def _redis_settings() -> RedisSettings:
    """Build arq ``RedisSettings`` from the configured Redis URL.

    Parses ``settings.redis_url`` (e.g. ``redis://:secret@host:6379/1``)
    and maps its components onto ``RedisSettings``, falling back to
    host ``"redis"``, port ``6379``, database ``0`` when absent.

    Returns:
        RedisSettings: connection settings for ``arq.create_pool``.
    """
    from urllib.parse import urlparse

    parsed = urlparse(settings.redis_url)
    return RedisSettings(
        host=parsed.hostname or "redis",
        port=parsed.port or 6379,
        database=int(parsed.path.lstrip("/") or "0"),
        # Fix: previously a password embedded in the URL was silently
        # dropped, so authenticated Redis instances were unreachable.
        password=parsed.password,
    )
|
|
|
|
|
|
@router.post("/jobs", response_model=JobResponse, status_code=201)
async def create_job(job_in: JobCreate, db: AsyncSession = Depends(get_db)):
    """Create a processing job and enqueue it on the arq worker queue.

    Args:
        job_in: requested job; for ``source_type == "youtube"`` a
            ``source_url`` is mandatory.
        db: request-scoped async database session.

    Returns:
        The persisted ``Job`` row (serialized via ``JobResponse``).

    Raises:
        HTTPException(400): youtube source without a ``source_url``.
    """
    if job_in.source_type == "youtube" and not job_in.source_url:
        raise HTTPException(400, "source_url required for youtube source")

    job = Job(
        source_type=job_in.source_type,
        source_url=job_in.source_url,
        status="pending",
    )
    db.add(job)
    await db.commit()
    await db.refresh(job)

    # Enqueue processing. Close the pool even if enqueue_job raises so we
    # never leak Redis connections (previously close() was skipped on error).
    pool = await create_pool(_redis_settings())
    try:
        await pool.enqueue_job("process_job", str(job.id))
    finally:
        await pool.close()

    return job
|
|
|
|
|
|
@router.post("/jobs/upload", response_model=JobResponse, status_code=201)
async def create_job_upload(
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Accept an uploaded media file, persist it to disk, and enqueue a job.

    The upload is streamed to ``settings.media_dir`` in 1 MiB chunks so
    large files are never fully buffered in memory.

    Args:
        file: the uploaded media file (multipart form field).
        db: request-scoped async database session.

    Returns:
        The persisted ``Job`` row (serialized via ``JobResponse``).
    """
    import os
    import aiofiles

    os.makedirs(settings.media_dir, exist_ok=True)
    # Fix: file.filename may be None (client sent no filename), which
    # previously crashed with AttributeError. Also strip backslashes in
    # addition to "/" and ".." to block path traversal on all platforms.
    raw_name = file.filename or "unnamed"
    safe_name = raw_name.replace("\\", "_").replace("/", "_").replace("..", "_")
    dest = os.path.join(settings.media_dir, f"upload_{safe_name}")

    async with aiofiles.open(dest, "wb") as f:
        while chunk := await file.read(1024 * 1024):
            await f.write(chunk)

    job = Job(
        source_type="upload",
        source_filename=safe_name,
        media_path=dest,
        status="pending",
    )
    db.add(job)
    await db.commit()
    await db.refresh(job)

    # Close the pool even if enqueue_job raises (previously leaked on error).
    pool = await create_pool(_redis_settings())
    try:
        await pool.enqueue_job("process_job", str(job.id))
    finally:
        await pool.close()

    return job
|
|
|
|
|
|
@router.get("/jobs", response_model=list[JobResponse])
async def list_jobs(
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db),
):
    """Return a page of jobs, newest first.

    Args:
        limit: maximum number of rows to return (default 20).
        offset: number of rows to skip before the page starts.
        db: request-scoped async database session.
    """
    stmt = (
        select(Job)
        .order_by(Job.created_at.desc())
        .offset(offset)
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/jobs/{job_id}", response_model=JobResponse)
async def get_job(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Fetch a single job by primary key.

    Raises:
        HTTPException(404): no job exists with ``job_id``.
    """
    found = await db.get(Job, job_id)
    if found is None:
        raise HTTPException(404, "Job not found")
    return found
|
|
|
|
|
|
@router.get("/jobs/{job_id}/clips", response_model=list[ClipResponse])
async def get_job_clips(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """List a job's clips ordered by descending virality score.

    Raises:
        HTTPException(404): no job exists with ``job_id``.
    """
    if await db.get(Job, job_id) is None:
        raise HTTPException(404, "Job not found")

    stmt = (
        select(Clip)
        .where(Clip.job_id == job_id)
        .order_by(Clip.virality_score.desc())
    )
    rows = await db.execute(stmt)
    clips = rows.scalars().all()

    # duration is a generated column not populated on these ORM
    # instances, so compute it here for the response schema.
    for clip in clips:
        clip.duration = clip.end_time - clip.start_time
    return clips
|
|
|
|
|
|
@router.get("/jobs/{job_id}/progress")
async def job_progress_sse(job_id: UUID, db: AsyncSession = Depends(get_db)):
    """Stream job progress as server-sent events.

    Emits the job's current state immediately, then relays updates
    published on the ``job:{id}:progress`` Redis channel until a terminal
    status ("complete" or "failed") is observed.

    Raises:
        HTTPException(404): no job exists with ``job_id``.
    """
    job = await db.get(Job, job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    async def event_stream():
        import redis.asyncio as aioredis

        r = aioredis.from_url(settings.redis_url)
        pubsub = r.pubsub()
        # Fix: one try/finally now guards every exit path. Previously the
        # unsubscribe/close pair was duplicated, and an exception raised
        # between subscribe() and the polling loop leaked the connection.
        try:
            await pubsub.subscribe(f"job:{job_id}:progress")

            # Send the current state immediately so subscribers don't wait
            # for the next published update.
            await db.refresh(job)
            yield {
                "event": "progress",
                "data": json.dumps({
                    "status": job.status,
                    "progress": job.progress,
                    "stage_message": job.stage_message,
                }),
            }

            if job.status in ("complete", "failed"):
                return  # terminal state: nothing more will be published

            while True:
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if msg and msg["type"] == "message":
                    data = json.loads(msg["data"])
                    yield {"event": "progress", "data": json.dumps(data)}
                    if data.get("status") in ("complete", "failed"):
                        break
                await asyncio.sleep(0.5)
        finally:
            await pubsub.unsubscribe()
            await r.close()

    return EventSourceResponse(event_stream())
|