314 lines
10 KiB
Python
314 lines
10 KiB
Python
"""AI clip analysis using Ollama (local LLM)."""
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
|
|
import httpx
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Max transcript chars to send to the LLM.
# Cloud APIs (Gemini, OpenRouter) handle 20K+ easily.
# CPU inference should use smaller values (~6K).
MAX_TRANSCRIPT_CHARS = 20000

# System prompt template. {min_dur}, {max_dur} and {target} are filled in via
# str.format() in analyze_transcript, so literal JSON braces in the example
# output are escaped as {{ and }}.
SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).

For each clip, provide:
- A catchy title (max 60 chars)
- Start and end timestamps (in seconds)
- Virality score (0-100)
- Category (one of: hook, story, insight, humor, emotional, controversial, educational)
- Brief reasoning for why this clip would go viral

Rules:
- Clips should be {min_dur}-{max_dur} seconds long
- Identify {target} clips, ranked by virality potential
- Clips should start and end at natural sentence boundaries
- Prefer clips with strong hooks in the first 3 seconds
- Look for emotional peaks, surprising statements, quotable moments
- Avoid clips that start mid-sentence or end abruptly

Respond ONLY with valid JSON in this exact format:
{{
"clips": [
{{
"title": "Clip title here",
"start_time": 12.5,
"end_time": 45.2,
"virality_score": 85,
"category": "hook",
"reasoning": "Why this clip would perform well"
}}
]
}}"""
|
|
|
|
|
|
def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]:
    """Reduce segments so the rendered transcript fits within max_chars.

    Keeps material from the beginning, middle, and end of the video so the
    LLM can surface clips from anywhere, not only the first N minutes.
    """
    # Pre-render every segment line once so its character cost can be measured.
    rendered = [
        (
            seg,
            f"[{_fmt_time(seg.get('start', 0))} - {_fmt_time(seg.get('end', 0))}] "
            f"{seg.get('text', '').strip()}",
        )
        for seg in segments
    ]

    total_chars = sum(len(txt) + 1 for _, txt in rendered)  # +1 for newline

    # Everything already fits — no sampling required.
    if total_chars <= max_chars:
        return segments

    # Budget split: ~15% head (hook detection), ~15% tail, rest sampled middle.
    head_budget = int(max_chars * 0.15)
    tail_budget = int(max_chars * 0.15)
    middle_budget = max_chars - head_budget - tail_budget

    def take(pairs, budget):
        """Greedily take (segment, line) pairs until the char budget runs out."""
        chosen, used = [], 0
        for seg, txt in pairs:
            cost = len(txt) + 1
            if used + cost > budget:
                break
            chosen.append(seg)
            used += cost
        return chosen

    # Head: consecutive segments from the very start.
    head = take(rendered, head_budget)

    # Tail: consecutive segments from the very end (gathered reversed, then
    # flipped back into chronological order).
    tail = take(reversed(rendered[len(head):]), tail_budget)
    tail.reverse()

    # Middle: stride-sample what's left between head and tail.
    middle_pairs = rendered[len(head):len(rendered) - len(tail)]
    middle = []
    if middle_pairs:
        avg_len = sum(len(txt) + 1 for _, txt in middle_pairs) / len(middle_pairs)
        wanted = int(middle_budget / avg_len)
        stride = max(1, len(middle_pairs) // max(1, wanted))
        used = 0
        for seg, txt in middle_pairs[::stride]:
            cost = len(txt) + 1
            if used + cost > middle_budget:
                break
            middle.append(seg)
            used += cost

    kept = head + middle + tail

    logger.info(
        f"Sampled {len(kept)}/{len(segments)} segments "
        f"({total_chars} -> ~{max_chars} chars) for LLM analysis"
    )
    return kept
|
|
|
|
|
|
async def analyze_transcript(
    transcript: dict,
    video_title: str = "",
    video_duration: float = 0,
) -> list[dict]:
    """Ask the configured LLM to pick the most viral-worthy clips.

    Args:
        transcript: dict with 'text', 'words', 'segments' from transcription service
        video_title: original video title for context
        video_duration: total video duration in seconds

    Returns:
        List of clip dicts with title, start_time, end_time, virality_score, category, reasoning
    """
    text = transcript.get("text", "")
    segments = transcript.get("segments", [])

    if not segments:
        # No per-segment timing available — fall back to plain text, truncated.
        timestamped = text[:MAX_TRANSCRIPT_CHARS]
    else:
        # Render sampled segments as "[MM:SS - MM:SS] text" lines.
        sampled = _sample_segments(segments, MAX_TRANSCRIPT_CHARS)
        rendered = [
            f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
            f"{s.get('text', '').strip()}"
            for s in sampled
        ]
        timestamped = "\n".join(rendered)

    # Fill the clip-duration / clip-count placeholders in the prompt template.
    system = SYSTEM_PROMPT.format(
        min_dur=settings.clip_min_duration,
        max_dur=settings.clip_max_duration,
        target=settings.target_clips,
    )

    user_prompt = f"""Video Title: {video_title}
Video Duration: {_fmt_time(video_duration)}

Transcript:
{timestamped}

Identify the {settings.target_clips} best viral clips from this transcript."""

    # Dispatch to the configured backend.
    use_openai = settings.ai_provider == "openai" and settings.openai_api_url
    if use_openai:
        content = await _call_openai(system, user_prompt)
    else:
        content = await _call_ollama(system, user_prompt)

    clips = _parse_clips(content, video_duration)

    logger.info(f"AI identified {len(clips)} clips")
    return clips
|
|
|
|
|
|
async def _call_ollama(system: str, user_prompt: str) -> str:
    """Call Ollama local API using streaming to avoid read timeouts.

    Args:
        system: system prompt text.
        user_prompt: user message containing the timestamped transcript.

    Returns:
        The model's full response text, assembled from streamed chunks.
    """
    logger.info(
        f"Sending to Ollama ({settings.ollama_model}), "
        f"prompt size: {len(user_prompt)} chars..."
    )
    # Use streaming to prevent httpx ReadTimeout on slow CPU inference.
    # Each streamed chunk resets the read timeout.
    timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
    content_parts = []
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{settings.ollama_url}/api/chat",
            json={
                "model": settings.ollama_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "stream": True,
                # NOTE(review): num_ctx=4096 tokens is far smaller than the
                # ~20K-char prompts MAX_TRANSCRIPT_CHARS permits; Ollama may
                # silently truncate long transcripts — confirm this is intended.
                "options": {
                    "temperature": 0.3,
                    "num_predict": 2048,
                    "num_ctx": 4096,
                },
            },
        ) as response:
            response.raise_for_status()
            # Ollama streams one JSON object per line (NDJSON).
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                    msg = chunk.get("message", {}).get("content", "")
                    if msg:
                        content_parts.append(msg)
                    # "done": true marks the final chunk of the stream.
                    if chunk.get("done"):
                        break
                except json.JSONDecodeError:
                    # Skip malformed/partial lines rather than aborting the stream.
                    continue

    content = "".join(content_parts)
    logger.info(f"Ollama response: {len(content)} chars")
    return content
|
|
|
|
|
|
async def _call_openai(system: str, user_prompt: str) -> str:
    """Call OpenAI-compatible API (RunPod vLLM, OpenRouter, etc.)."""
    logger.info(
        f"Sending to OpenAI API ({settings.openai_model}), "
        f"prompt size: {len(user_prompt)} chars..."
    )
    # Authorization header is optional (e.g. a local vLLM server needs none).
    headers = {"Content-Type": "application/json"}
    if settings.openai_api_key:
        headers["Authorization"] = f"Bearer {settings.openai_api_key}"

    payload = {
        "model": settings.openai_model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.3,
        "max_tokens": 2048,
    }

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(
            f"{settings.openai_api_url}/chat/completions",
            headers=headers,
            json=payload,
        )
        response.raise_for_status()
        result = response.json()

    return result["choices"][0]["message"]["content"]
|
|
|
|
|
|
def _parse_clips(content: str, video_duration: float) -> list[dict]:
|
|
"""Parse LLM response into clip list, handling imperfect JSON."""
|
|
# Try to extract JSON from response
|
|
json_match = re.search(r"\{[\s\S]*\}", content)
|
|
if not json_match:
|
|
logger.error(f"No JSON found in LLM response: {content[:200]}")
|
|
return []
|
|
|
|
try:
|
|
data = json.loads(json_match.group())
|
|
except json.JSONDecodeError:
|
|
# Try to fix common JSON issues
|
|
fixed = json_match.group()
|
|
fixed = re.sub(r",\s*}", "}", fixed)
|
|
fixed = re.sub(r",\s*]", "]", fixed)
|
|
try:
|
|
data = json.loads(fixed)
|
|
except json.JSONDecodeError:
|
|
logger.error(f"Failed to parse LLM JSON: {content[:200]}")
|
|
return []
|
|
|
|
raw_clips = data.get("clips", [])
|
|
clips = []
|
|
|
|
for c in raw_clips:
|
|
start = float(c.get("start_time", 0))
|
|
end = float(c.get("end_time", 0))
|
|
|
|
# Validate
|
|
if end <= start:
|
|
continue
|
|
if start < 0:
|
|
start = 0
|
|
if end > video_duration and video_duration > 0:
|
|
end = video_duration
|
|
|
|
clips.append({
|
|
"title": str(c.get("title", "Untitled"))[:100],
|
|
"start_time": round(start, 2),
|
|
"end_time": round(end, 2),
|
|
"virality_score": max(0, min(100, float(c.get("virality_score", 50)))),
|
|
"category": str(c.get("category", "general")),
|
|
"reasoning": str(c.get("reasoning", "")),
|
|
})
|
|
|
|
# Sort by virality score descending
|
|
clips.sort(key=lambda x: x["virality_score"], reverse=True)
|
|
return clips
|
|
|
|
|
|
def _fmt_time(seconds: float) -> str:
|
|
"""Format seconds as MM:SS."""
|
|
m, s = divmod(int(seconds), 60)
|
|
return f"{m:02d}:{s:02d}"
|