clip-forge/backend/app/services/ai_analysis.py

314 lines
10 KiB
Python

"""AI clip analysis using Ollama (local LLM)."""
import json
import logging
import re
import httpx
from app.config import settings
logger = logging.getLogger(__name__)

# Upper bound on transcript characters placed in the LLM prompt. Longer
# transcripts are down-sampled (when segments exist) or truncated (plain text).
# Hosted APIs (Gemini, OpenRouter) cope with 20K+ characters comfortably;
# for CPU inference a smaller value (~6K) is more practical.
MAX_TRANSCRIPT_CHARS = 20_000
SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).
For each clip, provide:
- A catchy title (max 60 chars)
- Start and end timestamps (in seconds)
- Virality score (0-100)
- Category (one of: hook, story, insight, humor, emotional, controversial, educational)
- Brief reasoning for why this clip would go viral
Rules:
- Clips should be {min_dur}-{max_dur} seconds long
- Identify {target} clips, ranked by virality potential
- Clips should start and end at natural sentence boundaries
- Prefer clips with strong hooks in the first 3 seconds
- Look for emotional peaks, surprising statements, quotable moments
- Avoid clips that start mid-sentence or end abruptly
Respond ONLY with valid JSON in this exact format:
{{
"clips": [
{{
"title": "Clip title here",
"start_time": 12.5,
"end_time": 45.2,
"virality_score": 85,
"category": "hook",
"reasoning": "Why this clip would perform well"
}}
]
}}"""
def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]:
    """Down-sample transcript segments so their rendered text fits ``max_chars``.

    Each segment is measured as the line sent to the LLM
    (``[MM:SS - MM:SS] text``) plus one newline. If everything fits, the input
    list is returned unchanged. Otherwise the budget is split roughly
    15% / 70% / 15%:

    * a contiguous run from the start (so opening hooks are visible),
    * an evenly strided sample across the whole middle,
    * a contiguous run from the end,

    so clips can be identified from all parts of the video, not just the
    first N minutes.

    Args:
        segments: transcript segments; presumably dicts with ``start``/``end``
            (seconds) and ``text`` keys. Missing keys default to 0 / "".
        max_chars: approximate character budget for the rendered transcript.

    Returns:
        The kept segments, in their original (chronological) order.
    """
    # Pair every segment with the size of its rendered line (+1 for newline).
    sized = []
    for s in segments:
        line = (
            f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
            f"{s.get('text', '').strip()}"
        )
        sized.append((s, len(line) + 1))

    total_chars = sum(cost for _, cost in sized)
    if total_chars <= max_chars:
        return segments

    start_budget = int(max_chars * 0.15)
    end_budget = int(max_chars * 0.15)
    middle_budget = max_chars - start_budget - end_budget

    # Contiguous run from the start (hook detection).
    head = []
    head_chars = 0
    for s, cost in sized:
        if head_chars + cost > start_budget:
            break
        head.append(s)
        head_chars += cost

    # Contiguous run from the end, never overlapping the head.
    tail = []
    tail_chars = 0
    for s, cost in reversed(sized[len(head):]):
        if tail_chars + cost > end_budget:
            break
        tail.append(s)
        tail_chars += cost
    tail.reverse()

    # Evenly strided sample of everything between head and tail.
    middle = sized[len(head):len(sized) - len(tail)]
    middle_kept = []
    middle_chars = 0
    if middle:
        avg_cost = sum(cost for _, cost in middle) / len(middle)
        target_lines = max(1, int(middle_budget / avg_cost))
        # Ceiling division: a floored stride (e.g. 100 // 60 == 1) would take
        # consecutive lines and exhaust the budget before ever reaching the
        # later part of the middle, defeating the even spread.
        step = max(1, -(-len(middle) // target_lines))
        for s, cost in middle[::step]:
            if middle_chars + cost > middle_budget:
                break
            middle_kept.append(s)
            middle_chars += cost

    kept = head + middle_kept + tail
    kept_chars = head_chars + middle_chars + tail_chars
    logger.info(
        f"Sampled {len(kept)}/{len(segments)} segments "
        f"({total_chars} -> {kept_chars} chars) for LLM analysis"
    )
    return kept
async def analyze_transcript(
    transcript: dict,
    video_title: str = "",
    video_duration: float = 0,
) -> list[dict]:
    """Ask the configured LLM backend to pick the best clips from a transcript.

    An OpenAI-compatible API is used when ``settings.ai_provider`` is
    ``"openai"`` and ``settings.openai_api_url`` is set; otherwise Ollama.

    Args:
        transcript: dict from the transcription service. Its ``segments``
            (timestamped) are preferred; ``text`` is the fallback when there
            are no segments.
        video_title: original video title, included in the prompt for context.
        video_duration: total video length in seconds; shown in the prompt and
            passed on for clamping clip end times.

    Returns:
        Clip dicts (title, start_time, end_time, virality_score, category,
        reasoning) as produced by ``_parse_clips``.
    """
    segments = transcript.get("segments", [])

    if not segments:
        # No timing information: send the plain text, cut to the size budget.
        transcript_block = transcript.get("text", "")[:MAX_TRANSCRIPT_CHARS]
    else:
        rendered = []
        for seg in _sample_segments(segments, MAX_TRANSCRIPT_CHARS):
            span = f"[{_fmt_time(seg.get('start', 0))} - {_fmt_time(seg.get('end', 0))}] "
            rendered.append(span + seg.get("text", "").strip())
        transcript_block = "\n".join(rendered)

    system = SYSTEM_PROMPT.format(
        min_dur=settings.clip_min_duration,
        max_dur=settings.clip_max_duration,
        target=settings.target_clips,
    )
    user_prompt = f"""Video Title: {video_title}
Video Duration: {_fmt_time(video_duration)}
Transcript:
{transcript_block}
Identify the {settings.target_clips} best viral clips from this transcript."""

    use_openai = settings.ai_provider == "openai" and settings.openai_api_url
    call_llm = _call_openai if use_openai else _call_ollama
    content = await call_llm(system, user_prompt)

    clips = _parse_clips(content, video_duration)
    logger.info(f"AI identified {len(clips)} clips")
    return clips
async def _call_ollama(system: str, user_prompt: str) -> str:
"""Call Ollama local API using streaming to avoid read timeouts."""
logger.info(
f"Sending to Ollama ({settings.ollama_model}), "
f"prompt size: {len(user_prompt)} chars..."
)
# Use streaming to prevent httpx ReadTimeout on slow CPU inference.
# Each streamed chunk resets the read timeout.
timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
content_parts = []
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST",
f"{settings.ollama_url}/api/chat",
json={
"model": settings.ollama_model,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user_prompt},
],
"stream": True,
"options": {
"temperature": 0.3,
"num_predict": 2048,
"num_ctx": 4096,
},
},
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line.strip():
continue
try:
chunk = json.loads(line)
msg = chunk.get("message", {}).get("content", "")
if msg:
content_parts.append(msg)
if chunk.get("done"):
break
except json.JSONDecodeError:
continue
content = "".join(content_parts)
logger.info(f"Ollama response: {len(content)} chars")
return content
async def _call_openai(system: str, user_prompt: str) -> str:
"""Call OpenAI-compatible API (RunPod vLLM, OpenRouter, etc.)."""
logger.info(
f"Sending to OpenAI API ({settings.openai_model}), "
f"prompt size: {len(user_prompt)} chars..."
)
headers = {"Content-Type": "application/json"}
if settings.openai_api_key:
headers["Authorization"] = f"Bearer {settings.openai_api_key}"
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
f"{settings.openai_api_url}/chat/completions",
headers=headers,
json={
"model": settings.openai_model,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user_prompt},
],
"temperature": 0.3,
"max_tokens": 2048,
},
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["message"]["content"]
def _parse_clips(content: str, video_duration: float) -> list[dict]:
"""Parse LLM response into clip list, handling imperfect JSON."""
# Try to extract JSON from response
json_match = re.search(r"\{[\s\S]*\}", content)
if not json_match:
logger.error(f"No JSON found in LLM response: {content[:200]}")
return []
try:
data = json.loads(json_match.group())
except json.JSONDecodeError:
# Try to fix common JSON issues
fixed = json_match.group()
fixed = re.sub(r",\s*}", "}", fixed)
fixed = re.sub(r",\s*]", "]", fixed)
try:
data = json.loads(fixed)
except json.JSONDecodeError:
logger.error(f"Failed to parse LLM JSON: {content[:200]}")
return []
raw_clips = data.get("clips", [])
clips = []
for c in raw_clips:
start = float(c.get("start_time", 0))
end = float(c.get("end_time", 0))
# Validate
if end <= start:
continue
if start < 0:
start = 0
if end > video_duration and video_duration > 0:
end = video_duration
clips.append({
"title": str(c.get("title", "Untitled"))[:100],
"start_time": round(start, 2),
"end_time": round(end, 2),
"virality_score": max(0, min(100, float(c.get("virality_score", 50)))),
"category": str(c.get("category", "general")),
"reasoning": str(c.get("reasoning", "")),
})
# Sort by virality score descending
clips.sort(key=lambda x: x["virality_score"], reverse=True)
return clips
def _fmt_time(seconds: float) -> str:
"""Format seconds as MM:SS."""
m, s = divmod(int(seconds), 60)
return f"{m:02d}:{s:02d}"