Aggressively optimize Ollama CPU inference speed

- Warm up both models on startup with keep_alive=24h (no cold starts)
- Use 16 threads for inference (server has 20 cores)
- Reduce context window to 1024 tokens, max output to 256
- Persistent httpx client for embedding calls (skip TCP handshake)
- Trim RAG chunks to 300 chars, history to 4 messages
- Shorter system prompt and context wrapper

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jeff Emmett 2026-02-17 01:12:04 -07:00
parent 3215283f97
commit 08be7716f9
4 changed files with 64 additions and 29 deletions

View File

@ -12,11 +12,23 @@ from app.models import Experience, Substance, DocumentChunk
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Persistent HTTP client for Ollama calls (avoids TCP handshake per request)
_ollama_client: httpx.AsyncClient | None = None
def _get_ollama_client() -> httpx.AsyncClient:
    """Return the shared persistent AsyncClient for Ollama, (re)creating it if missing or closed."""
    global _ollama_client
    client = _ollama_client
    if client is None or client.is_closed:
        # Long read timeout: embedding calls on a loaded model can still take a while.
        client = httpx.AsyncClient(
            base_url=settings.ollama_base_url,
            timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10),
        )
        _ollama_client = client
    return client
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]: def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
"""Split text into overlapping chunks by approximate token count (words / 0.75).""" """Split text into overlapping chunks by approximate token count (words / 0.75)."""
words = text.split() words = text.split()
# Approximate: 1 token ~ 0.75 words
words_per_chunk = int(chunk_size * 0.75) words_per_chunk = int(chunk_size * 0.75)
words_overlap = int(overlap * 0.75) words_overlap = int(overlap * 0.75)
@ -36,25 +48,18 @@ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Get embedding vector for text using Ollama.

    Uses the caller-supplied client when given; otherwise falls back to the
    persistent module-level client (skips a TCP handshake per request).
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    http = client if client is not None else _get_ollama_client()
    # NOTE(review): relative URL assumes the client has base_url set to the
    # Ollama server — true for the shared client; callers passing their own
    # client must configure base_url the same way.
    resp = await http.post(
        "/api/embeddings",
        json={
            "model": settings.ollama_embed_model,
            "prompt": text,
        },
    )
    resp.raise_for_status()
    return resp.json()["embedding"]
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]: async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:

View File

@ -8,6 +8,8 @@ from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_stream_timeout = httpx.Timeout(connect=10, read=300, write=10, pool=10)
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]: async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
"""Stream a chat completion from Ollama.""" """Stream a chat completion from Ollama."""
@ -20,14 +22,16 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato
"model": settings.ollama_chat_model, "model": settings.ollama_chat_model,
"messages": all_messages, "messages": all_messages,
"stream": True, "stream": True,
"keep_alive": "24h",
"options": { "options": {
"num_ctx": 2048, "num_ctx": 1024,
"num_predict": 512, "num_predict": 256,
"num_thread": 16,
"temperature": 0.7,
}, },
} }
timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30) async with httpx.AsyncClient(timeout=_stream_timeout) as client:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream( async with client.stream(
"POST", "POST",
f"{settings.ollama_base_url}/api/chat", f"{settings.ollama_base_url}/api/chat",

View File

@ -5,6 +5,7 @@ import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
import httpx
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -20,11 +21,38 @@ logger = logging.getLogger(__name__)
sessions: dict[str, list[dict]] = {} sessions: dict[str, list[dict]] = {}
async def warmup_models():
    """Pre-load the chat and embedding models into Ollama memory.

    Called once at startup so the first user request does not pay the
    model cold-start cost. Best-effort: any failure is logged and
    swallowed, and the models simply load lazily on first use.
    """
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            # Warm up embedding model. keep_alive pins it in memory so it
            # is not evicted after Ollama's default idle window — without
            # it only the chat model stayed resident.
            logger.info(f"Warming up embedding model: {settings.ollama_embed_model}")
            await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={
                    "model": settings.ollama_embed_model,
                    "prompt": "warmup",
                    "keep_alive": "24h",
                },
            )
            # Warm up chat model with keep_alive to hold it in memory.
            logger.info(f"Warming up chat model: {settings.ollama_chat_model}")
            await client.post(
                f"{settings.ollama_base_url}/api/generate",
                json={
                    "model": settings.ollama_chat_model,
                    "prompt": "hi",
                    # Minimal generation: one token is enough to force a load.
                    "options": {"num_predict": 1, "num_thread": 16},
                    "keep_alive": "24h",
                },
            )
            logger.info("Models warmed up and loaded into memory.")
    except Exception as e:
        # Deliberate broad catch: warmup is an optimization, never fatal.
        logger.warning(f"Model warmup failed (will load on first request): {e}")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Initializing database...") logger.info("Initializing database...")
await init_db() await init_db()
logger.info("Database ready.") logger.info("Database ready.")
await warmup_models()
yield yield

View File

@ -76,8 +76,7 @@ def build_context_prompt(chunks: list[dict]) -> str:
header += f" | Substance: {metadata['substance']}" header += f" | Substance: {metadata['substance']}"
header += " ---" header += " ---"
# Limit each chunk to avoid overwhelming the LLM content = chunk["content"][:300]
content = chunk["content"][:500]
context_parts.append(f"{header}\n{content}") context_parts.append(f"{header}\n{content}")
return "\n\n".join(context_parts) return "\n\n".join(context_parts)
@ -93,13 +92,12 @@ async def chat_stream(
# Build the context-augmented system prompt # Build the context-augmented system prompt
context_text = build_context_prompt(chunks) context_text = build_context_prompt(chunks)
full_system = f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n{context_text}\n--- END EROWID DATA ---" full_system = f"{SYSTEM_PROMPT}\n\nContext:\n{context_text}"
# Build message history # Build message history (keep minimal for speed)
messages = [] messages = []
if conversation_history: if conversation_history:
# Keep last 6 messages for context messages = conversation_history[-4:]
messages = conversation_history[-6:]
messages.append({"role": "user", "content": user_message}) messages.append({"role": "user", "content": user_message})