Aggressively optimize Ollama CPU inference speed

- Warm up both models on startup with keep_alive=24h (no cold starts)
- Use 16 threads for inference (server has 20 cores)
- Reduce context window to 1024 tokens, max output to 256
- Persistent httpx client for embedding calls (skip TCP handshake)
- Trim RAG chunks to 300 chars, history to 4 messages
- Shorter system prompt and context wrapper

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jeff Emmett 2026-02-17 01:12:04 -07:00
parent 3215283f97
commit 08be7716f9
4 changed files with 64 additions and 29 deletions

View File

@ -12,11 +12,23 @@ from app.models import Experience, Substance, DocumentChunk
logger = logging.getLogger(__name__)
# Persistent HTTP client for Ollama calls (avoids TCP handshake per request)
_ollama_client: httpx.AsyncClient | None = None
def _get_ollama_client() -> httpx.AsyncClient:
    """Return the shared module-level Ollama HTTP client, creating it lazily.

    Reusing one ``httpx.AsyncClient`` across calls avoids paying a TCP
    handshake per request. A new client is constructed only when none
    exists yet or the previous one has been closed.
    """
    global _ollama_client
    client = _ollama_client
    if client is None or client.is_closed:
        client = httpx.AsyncClient(
            base_url=settings.ollama_base_url,
            timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10),
        )
        _ollama_client = client
    return client
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
"""Split text into overlapping chunks by approximate token count (words / 0.75)."""
words = text.split()
# Approximate: 1 token ~ 0.75 words
words_per_chunk = int(chunk_size * 0.75)
words_overlap = int(overlap * 0.75)
@ -36,25 +48,18 @@ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Get embedding vector for text using Ollama.

    Args:
        text: The text to embed.
        client: Optional ``httpx.AsyncClient`` to reuse (e.g. for batch
            calls). When omitted, the shared persistent client is used so
            no per-call TCP handshake is paid.

    Returns:
        The embedding vector returned by Ollama's ``/api/embeddings``
        endpoint.

    Raises:
        httpx.HTTPStatusError: If Ollama responds with a non-2xx status.
    """
    c = client or _get_ollama_client()
    # Use the absolute URL: it works both with the shared client (which has
    # base_url set) and with caller-supplied clients that may not.
    resp = await c.post(
        f"{settings.ollama_base_url}/api/embeddings",
        json={
            "model": settings.ollama_embed_model,
            "prompt": text,
        },
    )
    resp.raise_for_status()
    data = resp.json()
    return data["embedding"]
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:

View File

@ -8,6 +8,8 @@ from app.config import settings
logger = logging.getLogger(__name__)
_stream_timeout = httpx.Timeout(connect=10, read=300, write=10, pool=10)
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
"""Stream a chat completion from Ollama."""
@ -20,14 +22,16 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato
"model": settings.ollama_chat_model,
"messages": all_messages,
"stream": True,
"keep_alive": "24h",
"options": {
"num_ctx": 2048,
"num_predict": 512,
"num_ctx": 1024,
"num_predict": 256,
"num_thread": 16,
"temperature": 0.7,
},
}
timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30)
async with httpx.AsyncClient(timeout=timeout) as client:
async with httpx.AsyncClient(timeout=_stream_timeout) as client:
async with client.stream(
"POST",
f"{settings.ollama_base_url}/api/chat",

View File

@ -5,6 +5,7 @@ import uuid
from contextlib import asynccontextmanager
from pathlib import Path
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
@ -20,11 +21,38 @@ logger = logging.getLogger(__name__)
sessions: dict[str, list[dict]] = {}
async def warmup_models():
    """Pre-load chat and embedding models into Ollama memory.

    Sends one tiny request to each model so Ollama loads them before the
    first real user request; the chat model is asked to stay resident via
    ``keep_alive``. Any failure is logged as a warning rather than raised,
    since the models will simply load lazily on first use instead.
    """
    # (log message, endpoint URL, request payload) for each warmup call.
    requests = (
        (
            f"Warming up embedding model: {settings.ollama_embed_model}",
            f"{settings.ollama_base_url}/api/embeddings",
            {"model": settings.ollama_embed_model, "prompt": "warmup"},
        ),
        (
            f"Warming up chat model: {settings.ollama_chat_model}",
            f"{settings.ollama_base_url}/api/generate",
            {
                "model": settings.ollama_chat_model,
                "prompt": "hi",
                "options": {"num_predict": 1, "num_thread": 16},
                "keep_alive": "24h",
            },
        ),
    )
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            for message, url, payload in requests:
                logger.info(message)
                await client.post(url, json=payload)
            logger.info("Models warmed up and loaded into memory.")
    except Exception as e:
        logger.warning(f"Model warmup failed (will load on first request): {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Initializing database...")
await init_db()
logger.info("Database ready.")
await warmup_models()
yield

View File

@ -76,8 +76,7 @@ def build_context_prompt(chunks: list[dict]) -> str:
header += f" | Substance: {metadata['substance']}"
header += " ---"
# Limit each chunk to avoid overwhelming the LLM
content = chunk["content"][:500]
content = chunk["content"][:300]
context_parts.append(f"{header}\n{content}")
return "\n\n".join(context_parts)
@ -93,13 +92,12 @@ async def chat_stream(
# Build the context-augmented system prompt
context_text = build_context_prompt(chunks)
full_system = f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n{context_text}\n--- END EROWID DATA ---"
full_system = f"{SYSTEM_PROMPT}\n\nContext:\n{context_text}"
# Build message history
# Build message history (keep minimal for speed)
messages = []
if conversation_history:
# Keep last 6 messages for context
messages = conversation_history[-6:]
messages = conversation_history[-4:]
messages.append({"role": "user", "content": user_message})