erowid-bot/app/main.py

169 lines
5.6 KiB
Python

import asyncio
import json
import logging
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from app.config import settings
from app.database import init_db, async_session
from app.rag import chat_stream
# Root logging configuration: INFO and above, "time [logger] LEVEL: message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Conversation history held in process memory, keyed by session id.
# Each value is a list of message dicts; contents are lost on restart.
sessions: dict[str, list[dict]] = {}
async def warmup_models():
    """Pre-load the embedding and chat models into Ollama's memory.

    Sends one tiny embedding request and one single-token generation request
    (with ``keep_alive="24h"``), so a later user request does not have to wait
    for the models to load.

    This is best-effort. Any failure is logged as a warning and swallowed so
    that application startup continues. Failures include connection errors,
    timeouts, and non-2xx HTTP statuses such as an unknown model name.
    """
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            # Warm up embedding model
            logger.info("Warming up embedding model: %s", settings.ollama_embed_model)
            response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={"model": settings.ollama_embed_model, "prompt": "warmup"},
            )
            # httpx does not raise on 4xx/5xx by itself. Without this check a
            # missing model would still be reported as "warmed up" below.
            response.raise_for_status()

            # Warm up chat model with keep_alive to hold it in memory
            logger.info("Warming up chat model: %s", settings.ollama_chat_model)
            response = await client.post(
                f"{settings.ollama_base_url}/api/generate",
                json={
                    "model": settings.ollama_chat_model,
                    "prompt": "hi",
                    # NOTE(review): num_thread is hard-coded; presumably tuned
                    # for the deployment host — confirm.
                    "options": {"num_predict": 1, "num_thread": 16},
                    "keep_alive": "24h",
                },
            )
            response.raise_for_status()
        logger.info("Models warmed up and loaded into memory.")
    except Exception as e:
        # Deliberately non-fatal: the models will be loaded lazily instead.
        logger.warning("Model warmup failed (will load on first request): %s", e)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: run startup work, then let the app serve.

    On startup it awaits ``init_db()`` and then ``warmup_models()``, which
    catches and logs its own failures. Nothing runs after the ``yield``, so
    there is no shutdown work.
    """
    # Step 1: database.
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database ready.")

    # Step 2: best-effort model preload.
    await warmup_models()

    # Control returns to FastAPI; the application serves requests from here.
    yield
# Front-end assets live in a "static" directory next to this module.
static_dir = Path(__file__).parent.joinpath("static")

app = FastAPI(title="Erowid Bot", lifespan=lifespan)

# Serve everything under static_dir at /static/...
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the chat UI page, ``static/index.html``.

    The file is read on every request. It is decoded explicitly as UTF-8, the
    same charset ``HTMLResponse`` uses to encode the reply, rather than with
    the platform's locale encoding. The read runs in a worker thread so the
    blocking file I/O does not stall the event loop.
    """
    page = static_dir / "index.html"
    return await asyncio.to_thread(page.read_text, encoding="utf-8")
@app.get("/health")
async def health():
    """Liveness check: always answers ``{"status": "ok"}``.

    It does not probe the database or Ollama; it only shows the process is
    serving HTTP.
    """
    return dict(status="ok")
@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint with streaming SSE response.

    Expects a JSON object body with:
        message: the user's text; must be a non-blank string.
        session_id: optional conversation id. When it is missing, empty, or
            not a string, a new UUID is generated.

    Returns a 400 JSON ``{"error": ...}`` for a malformed or non-object body,
    or for a missing, non-string, or blank message. Otherwise it streams
    ``text/event-stream`` ``data:`` events in this order:
        - one ``{"token": ...}`` per token yielded by ``chat_stream``;
        - an ``{"error": ...}`` event if generation raises;
        - a final ``{"done": true, "session_id": ...}``.

    The user/assistant turn is added to ``sessions`` only when generation
    finishes without error. Each session keeps at most the 20 most recent
    messages.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON (or undecodable bytes) is a client error, not a 500.
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    raw_message = body.get("message", "")
    if not isinstance(raw_message, str):
        return JSONResponse({"error": "'message' must be a string"}, status_code=400)
    message = raw_message.strip()
    if not message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    session_id = body.get("session_id")
    # Only a non-empty string continues an existing conversation. Other JSON
    # values could be unhashable (lists/objects) and would break the dict lookup.
    if not isinstance(session_id, str) or not session_id:
        session_id = str(uuid.uuid4())

    # Get or create conversation history.
    # NOTE(review): nothing here ever removes entries from `sessions`, so the
    # number of stored conversations grows for the life of the process.
    history = sessions.get(session_id, [])

    async def generate():
        parts: list[str] = []
        failed = False
        try:
            async for token in chat_stream(message, history):
                parts.append(token)
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as e:
            failed = True
            logger.exception("Chat error: %s", e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

        # Record the turn only if the reply completed. Otherwise an empty or
        # truncated answer would be replayed as context on the next message.
        if not failed:
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "".join(parts)})
            # Keep history bounded to the last 20 messages (a no-op when shorter).
            del history[:-20]
            sessions[session_id] = history

        yield f"data: {json.dumps({'done': True, 'session_id': session_id})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Tell nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/stats")
async def stats():
    """Get database stats.

    Returns the row counts of the experience, substance and document-chunk
    tables. A NULL count is reported as 0.
    """
    # Imported lazily, as in the rest of this module's endpoint-specific code.
    from sqlalchemy import func, select
    from app.models import Experience, Substance, DocumentChunk

    # Response key -> model to count, in the order the queries are issued.
    tables = {
        "experiences": Experience,
        "substances": Substance,
        "chunks": DocumentChunk,
    }
    counts = {}
    async with async_session() as db:
        for key, model in tables.items():
            result = await db.execute(select(func.count(model.id)))
            counts[key] = result.scalar() or 0
    return counts
# Strong references to running admin jobs. The event loop keeps only weak
# references to tasks, so a task whose result is discarded can be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Release a finished admin task and log how it ended, if it failed."""
    _background_tasks.discard(task)
    if task.cancelled():
        logger.warning("Background task %s was cancelled", task.get_name())
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background task %s failed", task.get_name(), exc_info=exc)


def _spawn_background(coro, name: str) -> asyncio.Task:
    """Schedule *coro* on the running loop and keep it referenced until done."""
    task = asyncio.create_task(coro, name=name)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def _read_limit(request: Request):
    """Return the optional ``limit`` from an admin request's JSON body.

    The body is only parsed when the media type is ``application/json``;
    parameters such as ``; charset=utf-8`` are allowed. For any other content
    type, or an empty JSON body, ``None`` (no limit) is returned.

    Raises:
        ValueError: the body is not valid JSON, is not a JSON object, or its
            ``limit`` is neither an integer nor null.
    """
    content_type = request.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        return None
    raw = await request.body()
    if not raw.strip():
        return None
    try:
        body = json.loads(raw)
    except ValueError as exc:
        raise ValueError(f"Invalid JSON body: {exc}") from exc
    if not isinstance(body, dict):
        raise ValueError("JSON body must be an object")
    limit = body.get("limit")
    # bool is a subclass of int but is never a meaningful limit.
    if limit is not None and (isinstance(limit, bool) or not isinstance(limit, int)):
        raise ValueError("'limit' must be an integer or null")
    return limit


@app.post("/admin/scrape/experiences")
async def trigger_scrape_experiences(request: Request):
    """Trigger experience scraping (admin endpoint).

    Accepts an optional JSON body ``{"limit": <int|null>}``. Returns 400 with
    ``{"error": ...}`` for an unusable body. Otherwise it starts the scrape as
    a background task and returns immediately.
    """
    try:
        limit = await _read_limit(request)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)
    from app.scraper.experiences import scrape_all_experiences
    _spawn_background(scrape_all_experiences(limit=limit), name="scrape-experiences")
    return {"status": "started", "message": "Experience scraping started in background"}


@app.post("/admin/scrape/substances")
async def trigger_scrape_substances(request: Request):
    """Trigger substance scraping (admin endpoint).

    Accepts an optional JSON body ``{"limit": <int|null>}``. Returns 400 with
    ``{"error": ...}`` for an unusable body. Otherwise it starts the scrape as
    a background task and returns immediately.
    """
    try:
        limit = await _read_limit(request)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)
    from app.scraper.substances import scrape_all_substances
    _spawn_background(scrape_all_substances(limit=limit), name="scrape-substances")
    return {"status": "started", "message": "Substance scraping started in background"}


@app.post("/admin/embed")
async def trigger_embedding():
    """Trigger embedding pipeline (admin endpoint).

    Starts ``embed_all()`` as a background task and returns immediately.
    """
    from app.embeddings import embed_all
    _spawn_background(embed_all(), name="embed-all")
    return {"status": "started", "message": "Embedding pipeline started in background"}