diff --git a/backend/app/config.py b/backend/app/config.py index f8ca9be..c068b54 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -12,9 +12,13 @@ class Settings(BaseSettings): whisper_api_url: str = "https://whisper.jeffemmett.com" whisper_model: str = "deepdml/faster-whisper-large-v3-turbo-ct2" - # Ollama + # AI Analysis - supports "ollama" or "openai" (OpenAI-compatible API) + ai_provider: str = "ollama" # "ollama" or "openai" ollama_url: str = "http://host.docker.internal:11434" ollama_model: str = "llama3.1:8b" + openai_api_url: str = "" # e.g. https://api.runpod.ai/v2/{endpoint_id}/openai/v1 + openai_api_key: str = "" + openai_model: str = "" # e.g. meta-llama/Llama-3.1-8B-Instruct # Storage media_dir: str = "/data/media" diff --git a/backend/app/services/ai_analysis.py b/backend/app/services/ai_analysis.py index c6e16b5..7f3fc9f 100644 --- a/backend/app/services/ai_analysis.py +++ b/backend/app/services/ai_analysis.py @@ -10,6 +10,11 @@ from app.config import settings logger = logging.getLogger(__name__) +# Max transcript chars to send to the LLM. +# Cloud APIs (Gemini, OpenRouter) handle 20K+ easily. +# CPU inference should use smaller values (~6K). +MAX_TRANSCRIPT_CHARS = 20000 + SYSTEM_PROMPT = """You are a viral video clip analyst. Given a video transcript with timestamps, identify the best short clips that would perform well on social media (TikTok, YouTube Shorts, Instagram Reels). For each clip, provide: @@ -42,6 +47,85 @@ Respond ONLY with valid JSON in this exact format: }}""" +def _sample_segments(segments: list[dict], max_chars: int) -> list[dict]: + """Sample segments evenly across the video to fit within max_chars. + + Takes segments from the beginning, middle, and end so clips can be + identified from all parts of the video, not just the first N minutes. + """ + # Build all lines first to measure sizes + lines = [] + for s in segments: + line = ( + f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] " + f"{s.get('text', '').strip()}" + ) + lines.append((s, line)) + + total_chars = sum(len(l) + 1 for _, l in lines) # +1 for newline + + if total_chars <= max_chars: + return segments + + # Need to sample. Keep every Nth segment to stay under the limit. + # Always include first and last few segments for context. + keep_chars = 0 + kept = [] + + # Reserve ~15% for start (hook detection), ~15% for end, ~70% sampled middle + start_budget = int(max_chars * 0.15) + end_budget = int(max_chars * 0.15) + middle_budget = max_chars - start_budget - end_budget + + # Take segments from the start + for s, line in lines: + if keep_chars + len(line) + 1 > start_budget: + break + kept.append(s) + keep_chars += len(line) + 1 + + start_count = len(kept) + + # Take segments from the end (reversed) + end_segs = [] + end_chars = 0 + for s, line in reversed(lines[start_count:]): + if end_chars + len(line) + 1 > end_budget: + break + end_segs.append(s) + end_chars += len(line) + 1 + + end_segs.reverse() + end_start_idx = len(lines) - len(end_segs) + + # Sample evenly from the middle + middle_lines = lines[start_count:end_start_idx] + if middle_lines: + # Calculate skip interval to fit middle_budget + middle_chars = 0 + middle_kept = [] + avg_line_len = sum(len(l) + 1 for _, l in middle_lines) / len(middle_lines) + target_lines = int(middle_budget / avg_line_len) + step = max(1, len(middle_lines) // max(1, target_lines)) + + for i in range(0, len(middle_lines), step): + s, line = middle_lines[i] + if middle_chars + len(line) + 1 > middle_budget: + break + middle_kept.append(s) + middle_chars += len(line) + 1 + + kept.extend(middle_kept) + + kept.extend(end_segs) + + logger.info( + f"Sampled {len(kept)}/{len(segments)} segments " + f"({total_chars} -> ~{max_chars} chars) for LLM analysis" + ) + return kept + + async def analyze_transcript( transcript: dict, video_title: str = "", @@ -62,14 +146,15 @@ async def analyze_transcript( segments = transcript.get("segments", []) if segments: + sampled = _sample_segments(segments, MAX_TRANSCRIPT_CHARS) timestamped = "\n".join( f"[{_fmt_time(s.get('start', 0))} - {_fmt_time(s.get('end', 0))}] " f"{s.get('text', '').strip()}" - for s in segments + for s in sampled ) else: - # Fall back to plain text with rough time estimates - timestamped = text + # Fall back to plain text, truncated + timestamped = text[:MAX_TRANSCRIPT_CHARS] system = SYSTEM_PROMPT.format( min_dur=settings.clip_min_duration, @@ -85,10 +170,30 @@ Transcript: Identify the {settings.target_clips} best viral clips from this transcript.""" - logger.info(f"Sending transcript to Ollama ({settings.ollama_model})...") + if settings.ai_provider == "openai" and settings.openai_api_url: + content = await _call_openai(system, user_prompt) + else: + content = await _call_ollama(system, user_prompt) - async with httpx.AsyncClient(timeout=1800.0) as client: - response = await client.post( + clips = _parse_clips(content, video_duration) + + logger.info(f"AI identified {len(clips)} clips") + return clips + + +async def _call_ollama(system: str, user_prompt: str) -> str: + """Call Ollama local API using streaming to avoid read timeouts.""" + logger.info( + f"Sending to Ollama ({settings.ollama_model}), " + f"prompt size: {len(user_prompt)} chars..." + ) + # Use streaming to prevent httpx ReadTimeout on slow CPU inference. + # Each streamed chunk resets the read timeout. + timeout = httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0) + content_parts = [] + async with httpx.AsyncClient(timeout=timeout) as client: + async with client.stream( + "POST", f"{settings.ollama_url}/api/chat", json={ "model": settings.ollama_model, @@ -96,21 +201,60 @@ Identify the {settings.target_clips} best viral clips from this transcript.""" {"role": "system", "content": system}, {"role": "user", "content": user_prompt}, ], - "stream": False, + "stream": True, "options": { "temperature": 0.3, - "num_predict": 4096, + "num_predict": 2048, + "num_ctx": 4096, }, }, + ) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line.strip(): + continue + try: + chunk = json.loads(line) + msg = chunk.get("message", {}).get("content", "") + if msg: + content_parts.append(msg) + if chunk.get("done"): + break + except json.JSONDecodeError: + continue + + content = "".join(content_parts) + logger.info(f"Ollama response: {len(content)} chars") + return content + + +async def _call_openai(system: str, user_prompt: str) -> str: + """Call OpenAI-compatible API (RunPod vLLM, OpenRouter, etc.).""" + logger.info( + f"Sending to OpenAI API ({settings.openai_model}), " + f"prompt size: {len(user_prompt)} chars..." + ) + headers = {"Content-Type": "application/json"} + if settings.openai_api_key: + headers["Authorization"] = f"Bearer {settings.openai_api_key}" + + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + f"{settings.openai_api_url}/chat/completions", + headers=headers, + json={ + "model": settings.openai_model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user_prompt}, + ], + "temperature": 0.3, + "max_tokens": 2048, + }, ) response.raise_for_status() result = response.json() - - content = result.get("message", {}).get("content", "") - clips = _parse_clips(content, video_duration) - - logger.info(f"AI identified {len(clips)} clips") - return clips + return result["choices"][0]["message"]["content"] def _parse_clips(content: str, video_duration: float) -> list[dict]: