feat: add cloud AI inference support (Gemini/OpenAI-compatible)

CPU-based Ollama inference on Netcup is too slow due to server memory
pressure. Add OpenAI-compatible API support so we can use Gemini Flash
or other cloud APIs for clip analysis. Also increase transcript sample
size to 20K chars since cloud APIs handle it easily.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jeff Emmett 2026-02-10 00:44:13 +00:00
parent d480c635ff
commit 362fe1e860
2 changed files with 163 additions and 15 deletions

View File

@ -12,9 +12,13 @@ class Settings(BaseSettings):
# Whisper transcription service endpoint and model id.
whisper_api_url: str = "https://whisper.jeffemmett.com"
whisper_model: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
# Ollama
# AI Analysis - supports "ollama" or "openai" (OpenAI-compatible API)
# Selects the analysis backend; anything other than "openai" (or an empty
# openai_api_url) falls back to the local Ollama endpoint below.
ai_provider: str = "ollama" # "ollama" or "openai"
# Local Ollama endpoint and model, used when ai_provider == "ollama".
ollama_url: str = "http://host.docker.internal:11434"
ollama_model: str = "llama3.1:8b"
# OpenAI-compatible endpoint (base URL up to and including /v1), optional
# bearer key, and model id. All empty by default so the Ollama path wins.
openai_api_url: str = "" # e.g. https://api.runpod.ai/v2/{endpoint_id}/openai/v1
openai_api_key: str = ""
openai_model: str = "" # e.g. meta-llama/Llama-3.1-8B-Instruct
# Storage
media_dir: str = "/data/media"

View File

@ -10,6 +10,11 @@ from app.config import settings
# Module-level logger for the clip-analysis service.
logger = logging.getLogger(__name__)
# Max transcript chars to send to the LLM.
# Cloud APIs (Gemini, OpenRouter) handle 20K+ easily.
# CPU inference should use smaller values (~6K).
MAX_TRANSCRIPT_CHARS = 20000
# System prompt for clip identification. The doubled braces ({{ }}) at the
# end keep the JSON example intact through the later SYSTEM_PROMPT.format(
# min_dur=...) call.
# NOTE(review): the middle of this literal is elided in this diff excerpt
# (a hunk marker sits inside the string); the full prompt lives in the file.
SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).
For each clip, provide:
@ -42,6 +47,85 @@ Respond ONLY with valid JSON in this exact format:
}}"""
def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]:
    """Pick a subset of segments whose rendered form fits in max_chars.

    Keeps material from the beginning, middle, and end of the video so
    clips can be identified anywhere, not just in the opening minutes.
    Returns the original list untouched when everything already fits.
    """
    # Render every segment once up front so sizes can be measured.
    rendered = [
        (
            seg,
            f"[{_fmt_time(seg.get('start', 0))} - {_fmt_time(seg.get('end', 0))}] "
            f"{seg.get('text', '').strip()}",
        )
        for seg in segments
    ]
    total_chars = sum(len(text) + 1 for _, text in rendered)  # +1 per newline
    if total_chars <= max_chars:
        return segments

    # Over budget: split it ~15% head (hook detection), ~15% tail, and the
    # remainder for an evenly sampled middle.
    head_budget = int(max_chars * 0.15)
    tail_budget = int(max_chars * 0.15)
    middle_budget = max_chars - head_budget - tail_budget

    # Head: consecutive segments from the start until the budget runs out.
    picked = []
    used = 0
    for seg, text in rendered:
        cost = len(text) + 1
        if used + cost > head_budget:
            break
        picked.append(seg)
        used += cost
    head_count = len(picked)

    # Tail: walk the remainder backwards, then restore chronological order.
    tail = []
    tail_used = 0
    for seg, text in reversed(rendered[head_count:]):
        cost = len(text) + 1
        if tail_used + cost > tail_budget:
            break
        tail.append(seg)
        tail_used += cost
    tail.reverse()
    tail_start = len(rendered) - len(tail)

    # Middle: keep every Nth segment, stride sized from the average line
    # length so the kept lines roughly fill middle_budget.
    middle = rendered[head_count:tail_start]
    if middle:
        avg_cost = sum(len(text) + 1 for _, text in middle) / len(middle)
        wanted = int(middle_budget / avg_cost)
        stride = max(1, len(middle) // max(1, wanted))
        mid_used = 0
        for seg, text in middle[::stride]:
            cost = len(text) + 1
            if mid_used + cost > middle_budget:
                break
            picked.append(seg)
            mid_used += cost

    picked.extend(tail)
    logger.info(
        f"Sampled {len(picked)}/{len(segments)} segments "
        f"({total_chars} -> ~{max_chars} chars) for LLM analysis"
    )
    return picked
# NOTE(review): this span is a unified-diff excerpt rendered without +/-
# markers, so stale pre-change lines appear interleaved with their
# replacements ("for s in segments" vs "for s in sampled", "timestamped =
# text" vs the truncated form, and the old logger.info / httpx.AsyncClient
# lines before _parse_clips). Do not read it as straight-line Python.
async def analyze_transcript(
transcript: dict,
video_title: str = "",
# Hunk marker: the rest of the signature (incl. video_duration, used
# below) is elided from this excerpt.
@ -62,14 +146,15 @@ async def analyze_transcript(
# Prefer timestamped segments, sampled down to MAX_TRANSCRIPT_CHARS.
segments = transcript.get("segments", [])
if segments:
sampled = _sample_segments(segments, MAX_TRANSCRIPT_CHARS)
timestamped = "\n".join(
f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
f"{s.get('text', '').strip()}"
for s in segments
for s in sampled
)
else:
# Fall back to plain text with rough time estimates
timestamped = text
# Fall back to plain text, truncated
timestamped = text[:MAX_TRANSCRIPT_CHARS]
# Fill the prompt template's placeholders with configured clip bounds.
system = SYSTEM_PROMPT.format(
min_dur=settings.clip_min_duration,
@ -85,10 +170,30 @@ Transcript:
Identify the {settings.target_clips} best viral clips from this transcript."""
logger.info(f"Sending transcript to Ollama ({settings.ollama_model})...")
# Route to the configured backend; defaults to Ollama unless the provider
# is "openai" AND an API URL is configured.
if settings.ai_provider == "openai" and settings.openai_api_url:
content = await _call_openai(system, user_prompt)
else:
content = await _call_ollama(system, user_prompt)
async with httpx.AsyncClient(timeout=1800.0) as client:
response = await client.post(
clips = _parse_clips(content, video_duration)
logger.info(f"AI identified {len(clips)} clips")
return clips
async def _call_ollama(system: str, user_prompt: str) -> str:
    """Call Ollama's local chat API using streaming.

    Streaming is used so each received chunk resets httpx's read timeout;
    slow CPU inference otherwise trips ReadTimeout on one long blocking read.

    Args:
        system: System prompt text.
        user_prompt: User message containing the (sampled) transcript.

    Returns:
        The full assistant message content, concatenated from the stream.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response from Ollama.
    """
    # Fix: this block was garbled by diff artifacts (a hunk marker mid-dict
    # and stale pre-change lines '"stream": False' / '"num_predict": 4096'
    # duplicating their replacements); reconstructed as the streaming version.
    logger.info(
        f"Sending to Ollama ({settings.ollama_model}), "
        f"prompt size: {len(user_prompt)} chars..."
    )
    # Use streaming to prevent httpx ReadTimeout on slow CPU inference.
    # Each streamed chunk resets the read timeout.
    timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
    content_parts = []
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{settings.ollama_url}/api/chat",
            json={
                "model": settings.ollama_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "stream": True,
                "options": {
                    "temperature": 0.3,
                    "num_predict": 2048,  # cap generation; the clips JSON fits well under this
                    "num_ctx": 4096,
                },
            },
        ) as response:
            response.raise_for_status()
            # Ollama streams one JSON object per line (NDJSON).
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                    msg = chunk.get("message", {}).get("content", "")
                    if msg:
                        content_parts.append(msg)
                    if chunk.get("done"):
                        break
                except json.JSONDecodeError:
                    # Skip malformed/partial lines rather than abort the stream.
                    continue
    content = "".join(content_parts)
    logger.info(f"Ollama response: {len(content)} chars")
    return content
async def _call_openai(system: str, user_prompt: str) -> str:
    """Call an OpenAI-compatible chat-completions API (RunPod vLLM, OpenRouter, etc.).

    Uses settings.openai_api_url as the base URL (up to /v1) and attaches a
    bearer Authorization header only when a key is configured.

    Args:
        system: System prompt text.
        user_prompt: User message containing the (sampled) transcript.

    Returns:
        The assistant message content from the first choice.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        KeyError / IndexError: if the response does not follow the standard
            OpenAI chat-completions schema.
    """
    # Fix: removed stale pre-change lines left by the diff rendering — they
    # re-parsed `result` with the Ollama message schema, referenced an
    # undefined `video_duration`, and returned before the correct statement.
    logger.info(
        f"Sending to OpenAI API ({settings.openai_model}), "
        f"prompt size: {len(user_prompt)} chars..."
    )
    headers = {"Content-Type": "application/json"}
    if settings.openai_api_key:
        headers["Authorization"] = f"Bearer {settings.openai_api_key}"
    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(
            f"{settings.openai_api_url}/chat/completions",
            headers=headers,
            json={
                "model": settings.openai_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 2048,
            },
        )
        response.raise_for_status()
        result = response.json()
    # Standard OpenAI schema: choices[0].message.content.
    return result["choices"][0]["message"]["content"]
def _parse_clips(content: str, video_duration: float) -> list[dict]: