Aggressively optimize Ollama CPU inference speed
- Warm up both models on startup with keep_alive=24h (no cold starts) - Use 16 threads for inference (server has 20 cores) - Reduce context window to 1024 tokens, max output to 256 - Persistent httpx client for embedding calls (skip TCP handshake) - Trim RAG chunks to 300 chars, history to 4 messages - Shorter system prompt and context wrapper Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3215283f97
commit
08be7716f9
|
|
@ -12,11 +12,23 @@ from app.models import Experience, Substance, DocumentChunk
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Persistent HTTP client for Ollama calls (avoids TCP handshake per request)
|
||||||
|
_ollama_client: httpx.AsyncClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ollama_client() -> httpx.AsyncClient:
|
||||||
|
global _ollama_client
|
||||||
|
if _ollama_client is None or _ollama_client.is_closed:
|
||||||
|
_ollama_client = httpx.AsyncClient(
|
||||||
|
base_url=settings.ollama_base_url,
|
||||||
|
timeout=httpx.Timeout(connect=10, read=120, write=10, pool=10),
|
||||||
|
)
|
||||||
|
return _ollama_client
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
|
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
|
||||||
"""Split text into overlapping chunks by approximate token count (words / 0.75)."""
|
"""Split text into overlapping chunks by approximate token count (words / 0.75)."""
|
||||||
words = text.split()
|
words = text.split()
|
||||||
# Approximate: 1 token ~ 0.75 words
|
|
||||||
words_per_chunk = int(chunk_size * 0.75)
|
words_per_chunk = int(chunk_size * 0.75)
|
||||||
words_overlap = int(overlap * 0.75)
|
words_overlap = int(overlap * 0.75)
|
||||||
|
|
||||||
|
|
@ -36,25 +48,18 @@ def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]
|
||||||
|
|
||||||
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
|
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
|
||||||
"""Get embedding vector for text using Ollama."""
|
"""Get embedding vector for text using Ollama."""
|
||||||
should_close = False
|
c = client or _get_ollama_client()
|
||||||
if client is None:
|
|
||||||
client = httpx.AsyncClient(timeout=60)
|
|
||||||
should_close = True
|
|
||||||
|
|
||||||
try:
|
resp = await c.post(
|
||||||
resp = await client.post(
|
"/api/embeddings",
|
||||||
f"{settings.ollama_base_url}/api/embeddings",
|
json={
|
||||||
json={
|
"model": settings.ollama_embed_model,
|
||||||
"model": settings.ollama_embed_model,
|
"prompt": text,
|
||||||
"prompt": text,
|
},
|
||||||
},
|
)
|
||||||
)
|
resp.raise_for_status()
|
||||||
resp.raise_for_status()
|
data = resp.json()
|
||||||
data = resp.json()
|
return data["embedding"]
|
||||||
return data["embedding"]
|
|
||||||
finally:
|
|
||||||
if should_close:
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
|
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
|
||||||
|
|
|
||||||
12
app/llm.py
12
app/llm.py
|
|
@ -8,6 +8,8 @@ from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_stream_timeout = httpx.Timeout(connect=10, read=300, write=10, pool=10)
|
||||||
|
|
||||||
|
|
||||||
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
|
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
|
||||||
"""Stream a chat completion from Ollama."""
|
"""Stream a chat completion from Ollama."""
|
||||||
|
|
@ -20,14 +22,16 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato
|
||||||
"model": settings.ollama_chat_model,
|
"model": settings.ollama_chat_model,
|
||||||
"messages": all_messages,
|
"messages": all_messages,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
|
"keep_alive": "24h",
|
||||||
"options": {
|
"options": {
|
||||||
"num_ctx": 2048,
|
"num_ctx": 1024,
|
||||||
"num_predict": 512,
|
"num_predict": 256,
|
||||||
|
"num_thread": 16,
|
||||||
|
"temperature": 0.7,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30)
|
async with httpx.AsyncClient(timeout=_stream_timeout) as client:
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"{settings.ollama_base_url}/api/chat",
|
f"{settings.ollama_base_url}/api/chat",
|
||||||
|
|
|
||||||
28
app/main.py
28
app/main.py
|
|
@ -5,6 +5,7 @@ import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
@ -20,11 +21,38 @@ logger = logging.getLogger(__name__)
|
||||||
sessions: dict[str, list[dict]] = {}
|
sessions: dict[str, list[dict]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def warmup_models():
|
||||||
|
"""Pre-load chat and embedding models into Ollama memory."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
# Warm up embedding model
|
||||||
|
logger.info(f"Warming up embedding model: {settings.ollama_embed_model}")
|
||||||
|
await client.post(
|
||||||
|
f"{settings.ollama_base_url}/api/embeddings",
|
||||||
|
json={"model": settings.ollama_embed_model, "prompt": "warmup"},
|
||||||
|
)
|
||||||
|
# Warm up chat model with keep_alive to hold in memory
|
||||||
|
logger.info(f"Warming up chat model: {settings.ollama_chat_model}")
|
||||||
|
await client.post(
|
||||||
|
f"{settings.ollama_base_url}/api/generate",
|
||||||
|
json={
|
||||||
|
"model": settings.ollama_chat_model,
|
||||||
|
"prompt": "hi",
|
||||||
|
"options": {"num_predict": 1, "num_thread": 16},
|
||||||
|
"keep_alive": "24h",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.info("Models warmed up and loaded into memory.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Model warmup failed (will load on first request): {e}")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Initializing database...")
|
logger.info("Initializing database...")
|
||||||
await init_db()
|
await init_db()
|
||||||
logger.info("Database ready.")
|
logger.info("Database ready.")
|
||||||
|
await warmup_models()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
10
app/rag.py
10
app/rag.py
|
|
@ -76,8 +76,7 @@ def build_context_prompt(chunks: list[dict]) -> str:
|
||||||
header += f" | Substance: {metadata['substance']}"
|
header += f" | Substance: {metadata['substance']}"
|
||||||
header += " ---"
|
header += " ---"
|
||||||
|
|
||||||
# Limit each chunk to avoid overwhelming the LLM
|
content = chunk["content"][:300]
|
||||||
content = chunk["content"][:500]
|
|
||||||
context_parts.append(f"{header}\n{content}")
|
context_parts.append(f"{header}\n{content}")
|
||||||
|
|
||||||
return "\n\n".join(context_parts)
|
return "\n\n".join(context_parts)
|
||||||
|
|
@ -93,13 +92,12 @@ async def chat_stream(
|
||||||
|
|
||||||
# Build the context-augmented system prompt
|
# Build the context-augmented system prompt
|
||||||
context_text = build_context_prompt(chunks)
|
context_text = build_context_prompt(chunks)
|
||||||
full_system = f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n{context_text}\n--- END EROWID DATA ---"
|
full_system = f"{SYSTEM_PROMPT}\n\nContext:\n{context_text}"
|
||||||
|
|
||||||
# Build message history
|
# Build message history (keep minimal for speed)
|
||||||
messages = []
|
messages = []
|
||||||
if conversation_history:
|
if conversation_history:
|
||||||
# Keep last 6 messages for context
|
messages = conversation_history[-4:]
|
||||||
messages = conversation_history[-6:]
|
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_message})
|
messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue