feat: add cloud AI inference support (Gemini/OpenAI-compatible)
CPU-based Ollama inference on Netcup is too slow due to server memory pressure. Add OpenAI-compatible API support so we can use Gemini Flash or other cloud APIs for clip analysis. Also increase transcript sample size to 20K chars since cloud APIs handle it easily. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d480c635ff
commit
362fe1e860
|
|
@ -12,9 +12,13 @@ class Settings(BaseSettings):
|
|||
whisper_api_url: str = "https://whisper.jeffemmett.com"
|
||||
whisper_model: str = "deepdml/faster-whisper-large-v3-turbo-ct2"
|
||||
|
||||
# Ollama
|
||||
# AI Analysis - supports "ollama" or "openai" (OpenAI-compatible API)
|
||||
ai_provider: str = "ollama" # "ollama" or "openai"
|
||||
ollama_url: str = "http://host.docker.internal:11434"
|
||||
ollama_model: str = "llama3.1:8b"
|
||||
openai_api_url: str = "" # e.g. https://api.runpod.ai/v2/{endpoint_id}/openai/v1
|
||||
openai_api_key: str = ""
|
||||
openai_model: str = "" # e.g. meta-llama/Llama-3.1-8B-Instruct
|
||||
|
||||
# Storage
|
||||
media_dir: str = "/data/media"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@ from app.config import settings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max transcript chars to send to the LLM.
|
||||
# Cloud APIs (Gemini, OpenRouter) handle 20K+ easily.
|
||||
# CPU inference should use smaller values (~6K).
|
||||
MAX_TRANSCRIPT_CHARS = 20000
|
||||
|
||||
SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).
|
||||
|
||||
For each clip, provide:
|
||||
|
|
@ -42,6 +47,85 @@ Respond ONLY with valid JSON in this exact format:
|
|||
}}"""
|
||||
|
||||
|
||||
def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]:
|
||||
"""Sample segments evenly across the video to fit within max_chars.
|
||||
|
||||
Takes segments from the beginning, middle, and end so clips can be
|
||||
identified from all parts of the video, not just the first N minutes.
|
||||
"""
|
||||
# Build all lines first to measure sizes
|
||||
lines = []
|
||||
for s in segments:
|
||||
line = (
|
||||
f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
|
||||
f"{s.get('text', '').strip()}"
|
||||
)
|
||||
lines.append((s, line))
|
||||
|
||||
total_chars = sum(len(l) + 1 for _, l in lines) # +1 for newline
|
||||
|
||||
if total_chars <= max_chars:
|
||||
return segments
|
||||
|
||||
# Need to sample. Keep every Nth segment to stay under the limit.
|
||||
# Always include first and last few segments for context.
|
||||
keep_chars = 0
|
||||
kept = []
|
||||
|
||||
# Reserve ~15% for start (hook detection), ~15% for end, ~70% sampled middle
|
||||
start_budget = int(max_chars * 0.15)
|
||||
end_budget = int(max_chars * 0.15)
|
||||
middle_budget = max_chars - start_budget - end_budget
|
||||
|
||||
# Take segments from the start
|
||||
for s, line in lines:
|
||||
if keep_chars + len(line) + 1 > start_budget:
|
||||
break
|
||||
kept.append(s)
|
||||
keep_chars += len(line) + 1
|
||||
|
||||
start_count = len(kept)
|
||||
|
||||
# Take segments from the end (reversed)
|
||||
end_segs = []
|
||||
end_chars = 0
|
||||
for s, line in reversed(lines[start_count:]):
|
||||
if end_chars + len(line) + 1 > end_budget:
|
||||
break
|
||||
end_segs.append(s)
|
||||
end_chars += len(line) + 1
|
||||
|
||||
end_segs.reverse()
|
||||
end_start_idx = len(lines) - len(end_segs)
|
||||
|
||||
# Sample evenly from the middle
|
||||
middle_lines = lines[start_count:end_start_idx]
|
||||
if middle_lines:
|
||||
# Calculate skip interval to fit middle_budget
|
||||
middle_chars = 0
|
||||
middle_kept = []
|
||||
avg_line_len = sum(len(l) + 1 for _, l in middle_lines) / len(middle_lines)
|
||||
target_lines = int(middle_budget / avg_line_len)
|
||||
step = max(1, len(middle_lines) // max(1, target_lines))
|
||||
|
||||
for i in range(0, len(middle_lines), step):
|
||||
s, line = middle_lines[i]
|
||||
if middle_chars + len(line) + 1 > middle_budget:
|
||||
break
|
||||
middle_kept.append(s)
|
||||
middle_chars += len(line) + 1
|
||||
|
||||
kept.extend(middle_kept)
|
||||
|
||||
kept.extend(end_segs)
|
||||
|
||||
logger.info(
|
||||
f"Sampled {len(kept)}/{len(segments)} segments "
|
||||
f"({total_chars} -> ~{max_chars} chars) for LLM analysis"
|
||||
)
|
||||
return kept
|
||||
|
||||
|
||||
async def analyze_transcript(
|
||||
transcript: dict,
|
||||
video_title: str = "",
|
||||
|
|
@ -62,14 +146,15 @@ async def analyze_transcript(
|
|||
segments = transcript.get("segments", [])
|
||||
|
||||
if segments:
|
||||
sampled = _sample_segments(segments, MAX_TRANSCRIPT_CHARS)
|
||||
timestamped = "\n".join(
|
||||
f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
|
||||
f"{s.get('text', '').strip()}"
|
||||
for s in segments
|
||||
for s in sampled
|
||||
)
|
||||
else:
|
||||
# Fall back to plain text with rough time estimates
|
||||
timestamped = text
|
||||
# Fall back to plain text, truncated
|
||||
timestamped = text[:MAX_TRANSCRIPT_CHARS]
|
||||
|
||||
system = SYSTEM_PROMPT.format(
|
||||
min_dur=settings.clip_min_duration,
|
||||
|
|
@ -85,10 +170,30 @@ Transcript:
|
|||
|
||||
Identify the {settings.target_clips} best viral clips from this transcript."""
|
||||
|
||||
logger.info(f"Sending transcript to Ollama ({settings.ollama_model})...")
|
||||
if settings.ai_provider == "openai" and settings.openai_api_url:
|
||||
content = await _call_openai(system, user_prompt)
|
||||
else:
|
||||
content = await _call_ollama(system, user_prompt)
|
||||
|
||||
async with httpx.AsyncClient(timeout=1800.0) as client:
|
||||
response = await client.post(
|
||||
clips = _parse_clips(content, video_duration)
|
||||
|
||||
logger.info(f"AI identified {len(clips)} clips")
|
||||
return clips
|
||||
|
||||
|
||||
async def _call_ollama(system: str, user_prompt: str) -> str:
|
||||
"""Call Ollama local API using streaming to avoid read timeouts."""
|
||||
logger.info(
|
||||
f"Sending to Ollama ({settings.ollama_model}), "
|
||||
f"prompt size: {len(user_prompt)} chars..."
|
||||
)
|
||||
# Use streaming to prevent httpx ReadTimeout on slow CPU inference.
|
||||
# Each streamed chunk resets the read timeout.
|
||||
timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
|
||||
content_parts = []
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{settings.ollama_url}/api/chat",
|
||||
json={
|
||||
"model": settings.ollama_model,
|
||||
|
|
@ -96,21 +201,60 @@ Identify the {settings.target_clips} best viral clips from this transcript."""
|
|||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"stream": False,
|
||||
"stream": True,
|
||||
"options": {
|
||||
"temperature": 0.3,
|
||||
"num_predict": 4096,
|
||||
"num_predict": 2048,
|
||||
"num_ctx": 4096,
|
||||
},
|
||||
},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(line)
|
||||
msg = chunk.get("message", {}).get("content", "")
|
||||
if msg:
|
||||
content_parts.append(msg)
|
||||
if chunk.get("done"):
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
content = "".join(content_parts)
|
||||
logger.info(f"Ollama response: {len(content)} chars")
|
||||
return content
|
||||
|
||||
|
||||
async def _call_openai(system: str, user_prompt: str) -> str:
|
||||
"""Call OpenAI-compatible API (RunPod vLLM, OpenRouter, etc.)."""
|
||||
logger.info(
|
||||
f"Sending to OpenAI API ({settings.openai_model}), "
|
||||
f"prompt size: {len(user_prompt)} chars..."
|
||||
)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if settings.openai_api_key:
|
||||
headers["Authorization"] = f"Bearer {settings.openai_api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.openai_api_url}/chat/completions",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": settings.openai_model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 2048,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
content = result.get("message", {}).get("content", "")
|
||||
clips = _parse_clips(content, video_duration)
|
||||
|
||||
logger.info(f"AI identified {len(clips)} clips")
|
||||
return clips
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
def _parse_clips(content: str, video_duration: float) -> list[dict]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue