Aggressively optimize Ollama CPU inference speed
- Warm up both models on startup with keep_alive=24h (no cold starts)
- Use 16 threads for inference (server has 20 cores)
- Reduce context window to 1024 tokens, max output to 256
- Persistent httpx client for embedding calls (skip TCP handshake)
- Trim RAG chunks to 300 chars, history to 4 messages
- Shorter system prompt and context wrapper

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3215283f97
commit
08be7716f9
|
|
@ -12,11 +12,23 @@ from app.models import Experience, Substance, DocumentChunk
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Persistent HTTP client for Ollama calls (avoids TCP handshake per request)
|
||||
_ollama_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _get_ollama_client() -> httpx.AsyncClient:
    """Return the shared Ollama HTTP client, (re)creating it when needed.

    A single persistent ``AsyncClient`` is reused across calls so repeated
    requests skip the TCP handshake; a missing or closed client is
    transparently replaced.
    """
    global _ollama_client
    needs_new = _ollama_client is None or _ollama_client.is_closed
    if needs_new:
        _ollama_client = httpx.AsyncClient(
            base_url=settings.ollama_base_url,
            timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10),
        )
    return _ollama_client
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
|
||||
"""Split text into overlapping chunks by approximate token count (words / 0.75)."""
|
||||
words = text.split()
|
||||
# Approximate: 1 token ~ 0.75 words
|
||||
words_per_chunk = int(chunk_size * 0.75)
|
||||
words_overlap = int(overlap * 0.75)
|
||||
|
||||
|
|
@ -36,25 +48,18 @@ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]
|
|||
|
||||
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Get embedding vector for *text* using Ollama.

    Args:
        text: Input text to embed.
        client: Optional caller-supplied ``httpx.AsyncClient``. When omitted,
            the shared persistent client is used, avoiding a TCP handshake
            per call.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        httpx.HTTPStatusError: If Ollama returns a non-2xx response.
    """
    c = client or _get_ollama_client()
    # Use the absolute URL so the request works both with the shared client
    # (which sets base_url) and with a caller-supplied client that does not —
    # httpx ignores base_url when the request URL is already absolute.
    resp = await c.post(
        f"{settings.ollama_base_url}/api/embeddings",
        json={
            "model": settings.ollama_embed_model,
            "prompt": text,
        },
    )
    resp.raise_for_status()
    data = resp.json()
    return data["embedding"]
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
|
||||
|
|
|
|||
12
app/llm.py
12
app/llm.py
|
|
@ -8,6 +8,8 @@ from app.config import settings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_stream_timeout = httpx.Timeout(connect=10, read=300, write=10, pool=10)
|
||||
|
||||
|
||||
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
|
||||
"""Stream a chat completion from Ollama."""
|
||||
|
|
@ -20,14 +22,16 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato
|
|||
"model": settings.ollama_chat_model,
|
||||
"messages": all_messages,
|
||||
"stream": True,
|
||||
"keep_alive": "24h",
|
||||
"options": {
|
||||
"num_ctx": 2048,
|
||||
"num_predict": 512,
|
||||
"num_ctx": 1024,
|
||||
"num_predict": 256,
|
||||
"num_thread": 16,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
|
||||
timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with httpx.AsyncClient(timeout=_stream_timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{settings.ollama_base_url}/api/chat",
|
||||
|
|
|
|||
28
app/main.py
28
app/main.py
|
|
@ -5,6 +5,7 @@ import uuid
|
|||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
|
@ -20,11 +21,38 @@ logger = logging.getLogger(__name__)
|
|||
sessions: dict[str, list[dict]] = {}
|
||||
|
||||
|
||||
async def warmup_models():
    """Pre-load chat and embedding models into Ollama memory.

    Best-effort: any failure is logged as a warning and swallowed so app
    startup proceeds; the models then load lazily on the first real request.
    """
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            # Warm up embedding model
            logger.info(f"Warming up embedding model: {settings.ollama_embed_model}")
            resp = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={"model": settings.ollama_embed_model, "prompt": "warmup"},
            )
            # A bad status (e.g. model not pulled) must surface as a warmup
            # failure instead of falling through to the success log below.
            resp.raise_for_status()
            # Warm up chat model with keep_alive to hold in memory
            logger.info(f"Warming up chat model: {settings.ollama_chat_model}")
            resp = await client.post(
                f"{settings.ollama_base_url}/api/generate",
                json={
                    "model": settings.ollama_chat_model,
                    "prompt": "hi",
                    "options": {"num_predict": 1, "num_thread": 16},
                    "keep_alive": "24h",
                },
            )
            resp.raise_for_status()
            logger.info("Models warmed up and loaded into memory.")
    except Exception as e:
        logger.warning(f"Model warmup failed (will load on first request): {e}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: initialize the database, then warm the Ollama models.

    Runs once at startup before the app serves traffic; teardown after the
    ``yield`` is a no-op.
    """
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database ready.")
    await warmup_models()
    yield
|
||||
|
||||
|
||||
|
|
|
|||
10
app/rag.py
10
app/rag.py
|
|
@ -76,8 +76,7 @@ def build_context_prompt(chunks: list[dict]) -> str:
|
|||
header += f" | Substance: {metadata['substance']}"
|
||||
header += " ---"
|
||||
|
||||
# Limit each chunk to avoid overwhelming the LLM
|
||||
content = chunk["content"][:500]
|
||||
content = chunk["content"][:300]
|
||||
context_parts.append(f"{header}\n{content}")
|
||||
|
||||
return "\n\n".join(context_parts)
|
||||
|
|
@ -93,13 +92,12 @@ async def chat_stream(
|
|||
|
||||
# Build the context-augmented system prompt
|
||||
context_text = build_context_prompt(chunks)
|
||||
full_system = f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n{context_text}\n--- END EROWID DATA ---"
|
||||
full_system = f"{SYSTEM_PROMPT}\n\nContext:\n{context_text}"
|
||||
|
||||
# Build message history
|
||||
# Build message history (keep minimal for speed)
|
||||
messages = []
|
||||
if conversation_history:
|
||||
# Keep last 6 messages for context
|
||||
messages = conversation_history[-6:]
|
||||
messages = conversation_history[-4:]
|
||||
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue