From 08be7716f9bedb3ccd967e13c074876ecc92d08b Mon Sep 17 00:00:00 2001 From: Jeff Emmett Date: Tue, 17 Feb 2026 01:12:04 -0700 Subject: [PATCH] Aggressively optimize Ollama CPU inference speed - Warm up both models on startup with keep_alive=24h (no cold starts) - Use 16 threads for inference (server has 20 cores) - Reduce context window to 1024 tokens, max output to 256 - Persistent httpx client for embedding calls (skip TCP handshake) - Trim RAG chunks to 300 chars, history to 4 messages - Shorter system prompt and context wrapper Co-Authored-By: Claude Opus 4.6 --- app/embeddings.py | 43 ++++++++++++++++++++++++------------------- app/llm.py | 12 ++++++++---- app/main.py | 28 ++++++++++++++++++++++++++++ app/rag.py | 10 ++++------ 4 files changed, 64 insertions(+), 29 deletions(-) diff --git a/app/embeddings.py b/app/embeddings.py index c5201f8..7b86899 100644 --- a/app/embeddings.py +++ b/app/embeddings.py @@ -12,11 +12,23 @@ from app.models import Experience, Substance, DocumentChunk logger = logging.getLogger(__name__) +# Persistent HTTP client for Ollama calls (avoids TCP handshake per request) +_ollama_client: httpx.AsyncClient | None = None + + +def _get_ollama_client() -> httpx.AsyncClient: + global _ollama_client + if _ollama_client is None or _ollama_client.is_closed: + _ollama_client = httpx.AsyncClient( + base_url=settings.ollama_base_url, + timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10), + ) + return _ollama_client + def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]: """Split text into overlapping chunks by approximate token count (words / 0.75).""" words = text.split() - # Approximate: 1 token ~ 0.75 words words_per_chunk = int(chunk_size * 0.75) words_overlap = int(overlap * 0.75) @@ -36,25 +48,18 @@ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str] async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]: """Get embedding vector for text using Ollama.""" - should_close = False - if client is None: - client = httpx.AsyncClient(timeout=60) - should_close = True + c = client or _get_ollama_client() - try: - resp = await client.post( - f"{settings.ollama_base_url}/api/embeddings", - json={ - "model": settings.ollama_embed_model, - "prompt": text, - }, - ) - resp.raise_for_status() - data = resp.json() - return data["embedding"] - finally: - if should_close: - await client.aclose() + resp = await c.post( + "/api/embeddings", + json={ + "model": settings.ollama_embed_model, + "prompt": text, + }, + ) + resp.raise_for_status() + data = resp.json() + return data["embedding"] async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]: diff --git a/app/llm.py b/app/llm.py index 11e467f..6c67d81 100644 --- a/app/llm.py +++ b/app/llm.py @@ -8,6 +8,8 @@ from app.config import settings logger = logging.getLogger(__name__) +_stream_timeout = httpx.Timeout(connect=10, read=300, write=10, pool=10) + async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]: """Stream a chat completion from Ollama.""" @@ -20,14 +22,16 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato "model": settings.ollama_chat_model, "messages": all_messages, "stream": True, + "keep_alive": "24h", "options": { - "num_ctx": 2048, - "num_predict": 512, + "num_ctx": 1024, + "num_predict": 256, + "num_thread": 16, + "temperature": 0.7, }, } - timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30) - async with httpx.AsyncClient(timeout=timeout) as client: + async with httpx.AsyncClient(timeout=_stream_timeout) as client: async with client.stream( "POST", f"{settings.ollama_base_url}/api/chat", diff --git a/app/main.py b/app/main.py index b4eed85..1f65af4 100644 --- a/app/main.py +++ b/app/main.py @@ -5,6 +5,7 @@ import uuid from contextlib import asynccontextmanager from pathlib import Path +import httpx from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse from fastapi.staticfiles import StaticFiles @@ -20,11 +21,38 @@ logger = logging.getLogger(__name__) sessions: dict[str, list[dict]] = {} +async def warmup_models(): + """Pre-load chat and embedding models into Ollama memory.""" + try: + async with httpx.AsyncClient(timeout=120) as client: + # Warm up embedding model + logger.info(f"Warming up embedding model: {settings.ollama_embed_model}") + await client.post( + f"{settings.ollama_base_url}/api/embeddings", + json={"model": settings.ollama_embed_model, "prompt": "warmup"}, + ) + # Warm up chat model with keep_alive to hold in memory + logger.info(f"Warming up chat model: {settings.ollama_chat_model}") + await client.post( + f"{settings.ollama_base_url}/api/generate", + json={ + "model": settings.ollama_chat_model, + "prompt": "hi", + "options": {"num_predict": 1, "num_thread": 16}, + "keep_alive": "24h", + }, + ) + logger.info("Models warmed up and loaded into memory.") + except Exception as e: + logger.warning(f"Model warmup failed (will load on first request): {e}") + + @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Initializing database...") await init_db() logger.info("Database ready.") + await warmup_models() yield diff --git a/app/rag.py b/app/rag.py index 25cfd4d..50d712d 100644 --- a/app/rag.py +++ b/app/rag.py @@ -76,8 +76,7 @@ def build_context_prompt(chunks: list[dict]) -> str: header += f" | Substance: {metadata['substance']}" header += " ---" - # Limit each chunk to avoid overwhelming the LLM - content = chunk["content"][:500] + content = chunk["content"][:300] context_parts.append(f"{header}\n{content}") return "\n\n".join(context_parts) @@ -93,13 +92,12 @@ async def chat_stream( # Build the context-augmented system prompt context_text = build_context_prompt(chunks) - full_system = f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n{context_text}\n--- END EROWID DATA ---" + full_system = f"{SYSTEM_PROMPT}\n\nContext:\n{context_text}" - # Build message history + # Build message history (keep minimal for speed) messages = [] if conversation_history: - # Keep last 6 messages for context - messages = conversation_history[-6:] + messages = conversation_history[-4:] messages.append({"role": "user", "content": user_message})