"""AI clip analysis using an LLM.

The transcript is sent either to a local Ollama server or to an
OpenAI-compatible chat-completions endpoint (selected via ``settings``),
and the model's JSON answer is parsed into a validated list of clips.
"""

import json
import logging
import math
import re

import httpx

from app.config import settings

logger = logging.getLogger(__name__)

# Max transcript chars to send to the LLM.
# Cloud APIs (Gemini, OpenRouter) handle 20K+ easily.
# CPU inference should use smaller values (~6K).
MAX_TRANSCRIPT_CHARS = 20000

SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels).

For each clip, provide:
- A catchy title (max 60 chars)
- Start and end timestamps (in seconds)
- Virality score (0-100)
- Category (one of: hook, story, insight, humor, emotional, controversial, educational)
- Brief reasoning for why this clip would go viral

Rules:
- Clips should be {min_dur}-{max_dur} seconds long
- Identify {target} clips, ranked by virality potential
- Clips should start and end at natural sentence boundaries
- Prefer clips with strong hooks in the first 3 seconds
- Look for emotional peaks, surprising statements, quotable moments
- Avoid clips that start mid-sentence or end abruptly

Respond ONLY with valid JSON in this exact format:
{{
  "clips": [
    {{
      "title": "Clip title here",
      "start_time": 12.5,
      "end_time": 45.2,
      "virality_score": 85,
      "category": "hook",
      "reasoning": "Why this clip would perform well"
    }}
  ]
}}"""

# Last-resort pattern for pulling clip fields out of JSON that cannot be
# repaired. Compiled once at import time instead of on every fallback.
_CLIP_FALLBACK_RE = re.compile(
    r'"title"\s*:\s*"([^"]*)".*?'
    r'"start_time"\s*:\s*([0-9.]+).*?'
    r'"end_time"\s*:\s*([0-9.]+).*?'
    r'"virality_score"\s*:\s*([0-9.]+).*?'
    r'"category"\s*:\s*"([^"]*)"',
    re.DOTALL,
)


def _format_segment(segment: dict) -> str:
    """Render one transcript segment as ``[start - end] text`` for the prompt."""
    return (
        f"[{segment.get('start', 0):.1f}s - {segment.get('end', 0):.1f}s] "
        f"{segment.get('text', '').strip()}"
    )


def _coerce_float(value: object) -> float | None:
    """Return *value* as a finite float, or ``None`` if that is not possible."""
    try:
        number = float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        return None
    # json.loads accepts NaN/Infinity; neither is a usable timestamp or score.
    return number if math.isfinite(number) else None


def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]:
    """Sample segments evenly across the video to fit within max_chars.

    Takes segments from the beginning, middle, and end so clips can be
    identified from all parts of the video, not just the first N minutes.

    Args:
        segments: transcript segments; each is read via ``start``, ``end``
            and ``text`` keys.
        max_chars: budget for the rendered transcript, counting one
            newline per segment line.

    Returns:
        ``segments`` itself when it already fits, otherwise a new list holding
        a chronological subset of the original segment dicts.
    """
    # Cost of each rendered line; +1 accounts for the joining newline.
    costs = [len(_format_segment(seg)) + 1 for seg in segments]
    total_chars = sum(costs)
    if total_chars <= max_chars:
        return segments

    # Reserve ~15% for start (hook detection), ~15% for end, ~70% sampled middle
    start_budget = int(max_chars * 0.15)
    end_budget = int(max_chars * 0.15)
    middle_budget = max_chars - start_budget - end_budget

    # Contiguous run from the start of the video.
    used = 0
    start_count = 0
    for cost in costs:
        if used + cost > start_budget:
            break
        used += cost
        start_count += 1

    # Contiguous run from the end, never overlapping the start run.
    used = 0
    end_count = 0
    for cost in reversed(costs[start_count:]):
        if used + cost > end_budget:
            break
        used += cost
        end_count += 1
    end_start_idx = len(segments) - end_count

    kept = list(segments[:start_count])

    middle_len = end_start_idx - start_count
    if middle_len > 0:
        avg_cost = sum(costs[start_count:end_start_idx]) / middle_len
        target_lines = max(1, int(middle_budget / avg_cost))
        # Ceiling division: a floored stride collapses to 1 whenever the
        # middle is less than 2x over budget, which keeps only the first part
        # of the middle instead of spreading the sample across all of it.
        step = -(-middle_len // target_lines)
        used = 0
        for idx in range(start_count, end_start_idx, step):
            # Skip (rather than stop at) a line that would overflow so that
            # one long segment does not cut off the rest of the video.
            if used + costs[idx] > middle_budget:
                continue
            kept.append(segments[idx])
            used += costs[idx]

    kept.extend(segments[end_start_idx:])

    logger.info(
        "Sampled %d/%d segments (%d -> ~%d chars) for LLM analysis",
        len(kept),
        len(segments),
        total_chars,
        max_chars,
    )
    return kept


async def analyze_transcript(
    transcript: dict,
    video_title: str = "",
    video_duration: float = 0,
) -> list[dict]:
    """Ask the configured LLM to identify the best clips from a transcript.

    The OpenAI-compatible endpoint is used when ``settings.ai_provider`` is
    ``"openai"`` and ``settings.openai_api_url`` is set; otherwise Ollama.

    Args:
        transcript: dict with 'text', 'words', 'segments' from transcription service
        video_title: original video title for context
        video_duration: total video duration in seconds

    Returns:
        List of clip dicts with title, start_time, end_time,
        virality_score, category, reasoning

    Raises:
        httpx.HTTPStatusError: if the LLM endpoint answers with an error status.
    """
    # Build timestamped transcript for the LLM
    text = transcript.get("text") or ""
    segments = transcript.get("segments", [])

    if segments:
        sampled = _sample_segments(segments, MAX_TRANSCRIPT_CHARS)
        timestamped = "\n".join(_format_segment(seg) for seg in sampled)
    else:
        # Fall back to plain text, truncated
        timestamped = text[:MAX_TRANSCRIPT_CHARS]

    system = SYSTEM_PROMPT.format(
        min_dur=settings.clip_min_duration,
        max_dur=settings.clip_max_duration,
        target=settings.target_clips,
    )

    user_prompt = f"""Video Title: {video_title}
Video Duration: {video_duration:.0f} seconds

Transcript:
{timestamped}

Identify the {settings.target_clips} best viral clips from this transcript."""

    if settings.ai_provider == "openai" and settings.openai_api_url:
        content = await _call_openai(system, user_prompt)
    else:
        content = await _call_ollama(system, user_prompt)

    clips = _parse_clips(content, video_duration)
    logger.info("AI identified %d clips", len(clips))
    return clips


async def _call_ollama(system: str, user_prompt: str) -> str:
    """Call Ollama local API using streaming to avoid read timeouts.

    Args:
        system: system prompt.
        user_prompt: user message containing the transcript.

    Returns:
        The concatenated ``message.content`` of all streamed chunks.

    Raises:
        httpx.HTTPStatusError: if Ollama answers with an error status.
    """
    logger.info(
        "Sending to Ollama (%s), prompt size: %d chars...",
        settings.ollama_model,
        len(user_prompt),
    )

    # Use streaming to prevent httpx ReadTimeout on slow CPU inference.
    # Each streamed chunk resets the read timeout.
    timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
    content_parts: list[str] = []

    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{settings.ollama_url}/api/chat",
            json={
                "model": settings.ollama_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "stream": True,
                "options": {
                    "temperature": 0.3,
                    "num_predict": 2048,
                    # NOTE(review): 4096 tokens looks small next to
                    # MAX_TRANSCRIPT_CHARS (20000 chars); presumably long
                    # prompts get truncated server-side. Confirm and tune
                    # the two values together.
                    "num_ctx": 4096,
                },
            },
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: ignore lines that are not valid JSON.
                    continue
                msg = chunk.get("message", {}).get("content", "")
                if msg:
                    content_parts.append(msg)
                if chunk.get("done"):
                    break

    content = "".join(content_parts)
    logger.info("Ollama response: %d chars", len(content))
    return content


async def _call_openai(system: str, user_prompt: str) -> str:
    """Call OpenAI-compatible API (RunPod vLLM, OpenRouter, etc.).

    Args:
        system: system prompt.
        user_prompt: user message containing the transcript.

    Returns:
        The first choice's message content, or ``""`` when the API returned
        no content for it.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with an error status.
        KeyError, IndexError: if the response body has no ``choices[0].message``.
    """
    logger.info(
        "Sending to OpenAI API (%s), prompt size: %d chars...",
        settings.openai_model,
        len(user_prompt),
    )

    headers = {"Content-Type": "application/json"}
    if settings.openai_api_key:
        headers["Authorization"] = f"Bearer {settings.openai_api_key}"

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(
            f"{settings.openai_api_url}/chat/completions",
            headers=headers,
            json={
                "model": settings.openai_model,
                "messages": [
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_prompt},
                ],
                "temperature": 0.3,
                "max_tokens": 2048,
            },
        )
        response.raise_for_status()
        result = response.json()

    # A null or missing content would break len() and the slice below;
    # treat it as an empty answer so the caller simply finds no clips.
    content = result["choices"][0]["message"].get("content") or ""
    logger.info("OpenAI API response (%d chars): %s", len(content), content[:300])
    return content


def _try_parse_json(text: str) -> dict | None:
    """Try multiple strategies to parse JSON from LLM output.

    Strategies, in order: direct parse, stripping trailing commas, escaping
    raw newlines/tabs inside string values, and finally a regex that extracts
    clip fields one object at a time.

    Args:
        text: candidate JSON text.

    Returns:
        The parsed value, a ``{"clips": [...]}`` dict built by the regex
        fallback, or ``None`` when nothing could be recovered.
    """
    # Attempt 1: direct parse
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Attempt 2: fix trailing commas
    fixed = re.sub(r",\s*}", "}", text)
    fixed = re.sub(r",\s*]", "]", fixed)
    try:
        return json.loads(fixed)
    except json.JSONDecodeError:
        pass

    # Attempt 3: fix unescaped control characters in string values
    # Replace literal newlines/tabs inside JSON strings
    fixed2 = re.sub(
        r'(?<=": ")(.*?)(?="[,\s}])',
        lambda m: m.group().replace("\n", "\\n").replace("\t", "\\t"),
        fixed,
        flags=re.DOTALL,
    )
    try:
        return json.loads(fixed2)
    except json.JSONDecodeError:
        pass

    # Attempt 4: extract individual clip objects with a more lenient approach
    clips = []
    for m in _CLIP_FALLBACK_RE.finditer(text):
        try:
            start_time = float(m.group(2))
            end_time = float(m.group(3))
            virality_score = float(m.group(4))
        except ValueError:
            # The pattern admits strings like "1.2.3"; drop only that clip
            # instead of discarding everything recovered so far.
            continue
        clips.append({
            "title": m.group(1),
            "start_time": start_time,
            "end_time": end_time,
            "virality_score": virality_score,
            "category": m.group(5),
            "reasoning": "",
        })
    if clips:
        logger.info("Regex fallback extracted %d clips", len(clips))
        return {"clips": clips}

    return None


def _parse_clips(content: str, video_duration: float) -> list[dict]:
    """Parse LLM response into clip list, handling imperfect JSON.

    Malformed clip entries are skipped individually so one bad entry does not
    discard the rest. Times are clamped to ``[0, video_duration]`` (the upper
    bound only when ``video_duration > 0``) before the clip is validated.

    Args:
        content: raw LLM response text.
        video_duration: total video duration in seconds; ``0`` disables the
            upper clamp.

    Returns:
        Clip dicts sorted by ``virality_score`` descending; empty if no JSON
        could be recovered.
    """
    # Strip markdown code fences (e.g. ```json ... ```)
    content = re.sub(r"```(?:json)?\s*", "", content).strip()

    # Try to extract JSON from response
    json_match = re.search(r"\{[\s\S]*\}", content)
    if not json_match:
        logger.error("No JSON found in LLM response: %s", content[:200])
        return []

    raw_json = json_match.group()
    logger.info("Extracted JSON (%d chars)", len(raw_json))

    data = _try_parse_json(raw_json)
    if data is None:
        logger.error("Failed to parse LLM JSON. Raw content: %s", content[:1000])
        return []

    raw_clips = data.get("clips", []) if isinstance(data, dict) else []
    if not isinstance(raw_clips, list):
        # e.g. "clips": null or "clips": {...}
        raw_clips = []

    clips = []
    for c in raw_clips:
        if not isinstance(c, dict):
            continue

        start = _coerce_float(c.get("start_time", 0))
        end = _coerce_float(c.get("end_time", 0))
        if start is None or end is None:
            logger.warning("Skipping clip with invalid timestamps: %r", c)
            continue

        score = _coerce_float(c.get("virality_score", 50))
        if score is None:
            # An unusable score should not cost us an otherwise valid clip.
            score = 50.0

        # Clamp first, then validate: checking before clamping lets through
        # clips that start past the end of the video or end before 0.
        if start < 0:
            start = 0
        if video_duration > 0 and end > video_duration:
            end = video_duration
        if end <= start:
            continue

        clips.append({
            "title": str(c.get("title", "Untitled"))[:100],
            "start_time": round(start, 2),
            "end_time": round(end, 2),
            "virality_score": max(0, min(100, score)),
            "category": str(c.get("category", "general")),
            "reasoning": str(c.get("reasoning", "")),
        })

    # Sort by virality score descending
    clips.sort(key=lambda x: x["virality_score"], reverse=True)
    return clips