clip-forge/backend/app/services/ai_analysis.py

170 lines
5.2 KiB
Python

"""AI clip analysis using Ollama (local LLM)."""
import json
import logging
import re
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
# Prompt template for the clip-analysis LLM call.  It is filled in at request
# time via str.format with {min_dur}, {max_dur}, and {target}; the doubled
# braces ({{ }}) around the JSON example escape them from str.format so the
# model sees literal JSON.  NOTE: the string content is sent verbatim to the
# model — do not edit casually.
SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).
For each clip, provide:
- A catchy title (max 60 chars)
- Start and end timestamps (in seconds)
- Virality score (0-100)
- Category (one of: hook, story, insight, humor, emotional, controversial, educational)
- Brief reasoning for why this clip would go viral
Rules:
- Clips should be {min_dur}-{max_dur} seconds long
- Identify {target} clips, ranked by virality potential
- Clips should start and end at natural sentence boundaries
- Prefer clips with strong hooks in the first 3 seconds
- Look for emotional peaks, surprising statements, quotable moments
- Avoid clips that start mid-sentence or end abruptly
Respond ONLY with valid JSON in this exact format:
{{
"clips": [
{{
"title": "Clip title here",
"start_time": 12.5,
"end_time": 45.2,
"virality_score": 85,
"category": "hook",
"reasoning": "Why this clip would perform well"
}}
]
}}"""
async def analyze_transcript(
    transcript: dict,
    video_title: str = "",
    video_duration: float = 0,
) -> list[dict]:
    """Use Ollama to identify the best clips from a transcript.

    Args:
        transcript: dict with 'text', 'words', 'segments' from transcription service
        video_title: original video title for context
        video_duration: total video duration in seconds

    Returns:
        List of clip dicts with title, start_time, end_time, virality_score,
        category, reasoning (sorted by virality score, best first).

    Raises:
        httpx.HTTPStatusError: if Ollama returns a non-2xx response.
    """
    # Build a timestamped transcript so the model can cite real boundaries.
    text = transcript.get("text", "")
    segments = transcript.get("segments", [])
    if segments:
        timestamped = "\n".join(
            f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] "
            f"{s.get('text', '').strip()}"
            for s in segments
        )
    else:
        # No per-segment timing available — fall back to plain text; the model
        # will have to estimate timestamps from context.
        timestamped = text

    system = SYSTEM_PROMPT.format(
        min_dur=settings.clip_min_duration,
        max_dur=settings.clip_max_duration,
        target=settings.target_clips,
    )
    user_prompt = f"""Video Title: {video_title}
Video Duration: {_fmt_time(video_duration)}
Transcript:
{timestamped}
Identify the {settings.target_clips} best viral clips from this transcript."""

    # Lazy %-style args: the message is only formatted if the level is enabled.
    logger.info("Sending transcript to Ollama (%s)...", settings.ollama_model)
    # Generous timeout (30 min): local models can be very slow on long transcripts.
    async with httpx.AsyncClient(timeout=1800.0) as client:
        response = await client.post(
            f"{settings.ollama_url}/api/chat",
            json={
                "model": settings.ollama_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "stream": False,
                "options": {
                    # Low temperature for consistent, parseable JSON output.
                    "temperature": 0.3,
                    "num_predict": 4096,
                },
            },
        )
        response.raise_for_status()
        result = response.json()

    content = result.get("message", {}).get("content", "")
    clips = _parse_clips(content, video_duration)
    logger.info("AI identified %d clips", len(clips))
    return clips
def _parse_clips(content: str, video_duration: float) -> list[dict]:
"""Parse LLM response into clip list, handling imperfect JSON."""
# Try to extract JSON from response
json_match = re.search(r"\{[\s\S]*\}", content)
if not json_match:
logger.error(f"No JSON found in LLM response: {content[:200]}")
return []
try:
data = json.loads(json_match.group())
except json.JSONDecodeError:
# Try to fix common JSON issues
fixed = json_match.group()
fixed = re.sub(r",\s*}", "}", fixed)
fixed = re.sub(r",\s*]", "]", fixed)
try:
data = json.loads(fixed)
except json.JSONDecodeError:
logger.error(f"Failed to parse LLM JSON: {content[:200]}")
return []
raw_clips = data.get("clips", [])
clips = []
for c in raw_clips:
start = float(c.get("start_time", 0))
end = float(c.get("end_time", 0))
# Validate
if end <= start:
continue
if start < 0:
start = 0
if end > video_duration and video_duration > 0:
end = video_duration
clips.append({
"title": str(c.get("title", "Untitled"))[:100],
"start_time": round(start, 2),
"end_time": round(end, 2),
"virality_score": max(0, min(100, float(c.get("virality_score", 50)))),
"category": str(c.get("category", "general")),
"reasoning": str(c.get("reasoning", "")),
})
# Sort by virality score descending
clips.sort(key=lambda x: x["virality_score"], reverse=True)
return clips
def _fmt_time(seconds: float) -> str:
"""Format seconds as MM:SS."""
m, s = divmod(int(seconds), 60)
return f"{m:02d}:{s:02d}"