"""FastAPI entry point for the Erowid Bot.

Exposes the static front-end, a streaming (SSE) chat endpoint backed by
``app.rag.chat_stream``, a database statistics endpoint, and admin endpoints
that start scraping / embedding jobs as background asyncio tasks.
"""

import asyncio
import json
import logging
import uuid
from contextlib import asynccontextmanager
from pathlib import Path

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from app.config import settings
from app.database import init_db, async_session
from app.rag import chat_stream

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Maximum number of messages (user + assistant) retained per session.
MAX_HISTORY_MESSAGES = 20
# Maximum number of sessions kept in memory; session ids are client-supplied,
# so without a cap the store could grow without bound.
MAX_SESSIONS = 1000

# In-memory session store (conversation history per session)
sessions: dict[str, list[dict]] = {}

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes.
_background_tasks: set = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Drop the finished task from the registry and log any failure."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background task failed: %r", exc, exc_info=exc)


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


def _is_json_request(request: Request) -> bool:
    """Return True when the request's media type is ``application/json``.

    Header parameters such as ``; charset=utf-8`` and letter case are ignored.
    """
    content_type = request.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    return media_type == "application/json"


async def _parse_json_object(request: Request) -> dict:
    """Parse the request body as a JSON object.

    Returns an empty dict for an empty body.

    Raises:
        ValueError: if the body is not valid JSON or is not a JSON object.
    """
    raw = await request.body()
    if not raw.strip():
        return {}
    # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    payload = json.loads(raw)
    if not isinstance(payload, dict):
        raise ValueError("JSON body must be an object")
    return payload


async def _optional_json_object(request: Request) -> dict:
    """Return the JSON body as a dict, or ``{}`` when the request is not JSON.

    Raises:
        ValueError: if the request declares JSON but the body is malformed.
    """
    if not _is_json_request(request):
        return {}
    return await _parse_json_object(request)


def _store_turn(
    session_id: str,
    history: list,
    user_message: str,
    assistant_message: str,
) -> None:
    """Append one completed exchange to *history* and save it in ``sessions``.

    The history is trimmed in place to the last ``MAX_HISTORY_MESSAGES``
    entries, and the oldest-written sessions are evicted once more than
    ``MAX_SESSIONS`` are stored.
    """
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": assistant_message})
    # Keep history bounded
    if len(history) > MAX_HISTORY_MESSAGES:
        history[:] = history[-MAX_HISTORY_MESSAGES:]

    # Re-insert so this session becomes the most recently written entry
    # (dicts preserve insertion order), then evict from the front.
    sessions.pop(session_id, None)
    sessions[session_id] = history
    while len(sessions) > MAX_SESSIONS:
        del sessions[next(iter(sessions))]


async def warmup_models():
    """Pre-load chat and embedding models into Ollama memory.

    Best effort: any failure (including a non-2xx response) is logged as a
    warning and startup continues.
    """
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            # Warm up embedding model
            logger.info("Warming up embedding model: %s", settings.ollama_embed_model)
            embed_response = await client.post(
                f"{settings.ollama_base_url}/api/embeddings",
                json={"model": settings.ollama_embed_model, "prompt": "warmup"},
            )
            embed_response.raise_for_status()

            # Warm up chat model with keep_alive to hold in memory
            logger.info("Warming up chat model: %s", settings.ollama_chat_model)
            chat_response = await client.post(
                f"{settings.ollama_base_url}/api/generate",
                json={
                    "model": settings.ollama_chat_model,
                    "prompt": "hi",
                    "options": {"num_predict": 1, "num_thread": 16},
                    "keep_alive": "24h",
                },
            )
            chat_response.raise_for_status()
            logger.info("Models warmed up and loaded into memory.")
    except Exception as e:
        # Deliberately broad: warmup must never prevent the app from starting.
        logger.warning("Model warmup failed (will load on first request): %s", e)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the database and warm up models before serving requests."""
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database ready.")
    await warmup_models()
    yield


app = FastAPI(title="Erowid Bot", lifespan=lifespan)

# Serve static files
static_dir = Path(__file__).parent / "static"
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page from ``static/index.html``."""
    # Explicit encoding so the result does not depend on the host locale.
    return (static_dir / "index.html").read_text(encoding="utf-8")


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint with streaming SSE response.

    Expects a JSON object with ``message`` (required, non-blank string) and
    an optional ``session_id``. Responds with ``text/event-stream`` events:
    ``{"token": ...}`` per generated token, ``{"error": ...}`` if generation
    fails, and a final ``{"done": true, "session_id": ...}``.

    Returns HTTP 400 when the body is not a JSON object or the message is
    missing, blank, or not a string.
    """
    try:
        body = await _parse_json_object(request)
    except ValueError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    raw_message = body.get("message", "")
    message = raw_message.strip() if isinstance(raw_message, str) else ""
    if not message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Coerce to str so the id is always hashable and JSON-serializable.
    raw_session_id = body.get("session_id", "")
    session_id = str(raw_session_id) if raw_session_id else str(uuid.uuid4())

    # Get or create conversation history
    history = sessions.get(session_id, [])

    async def generate():
        tokens = []
        failed = False
        try:
            async for token in chat_stream(message, history):
                tokens.append(token)
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as e:
            failed = True
            logger.exception("Chat error: %s", e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

        # Save to history only for completed turns, so a failed generation
        # does not leave an empty or truncated assistant message behind.
        if not failed:
            _store_turn(session_id, history, message, "".join(tokens))

        yield f"data: {json.dumps({'done': True, 'session_id': session_id})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@app.get("/stats")
async def stats():
    """Get database stats: row counts for experiences, substances and chunks."""
    from sqlalchemy import func, select
    from app.models import Experience, Substance, DocumentChunk

    async with async_session() as db:
        exp_count = (await db.execute(select(func.count(Experience.id)))).scalar() or 0
        sub_count = (await db.execute(select(func.count(Substance.id)))).scalar() or 0
        chunk_count = (await db.execute(select(func.count(DocumentChunk.id)))).scalar() or 0

        return {
            "experiences": exp_count,
            "substances": sub_count,
            "chunks": chunk_count,
        }


@app.post("/admin/scrape/experiences")
async def trigger_scrape_experiences(request: Request):
    """Trigger experience scraping (admin endpoint).

    An optional JSON body may carry ``limit``, which is forwarded to the
    scraper. Returns HTTP 400 if a JSON body is declared but malformed.
    """
    try:
        body = await _optional_json_object(request)
    except ValueError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    limit = body.get("limit")

    from app.scraper.experiences import scrape_all_experiences

    _spawn_background(scrape_all_experiences(limit=limit))
    return {"status": "started", "message": "Experience scraping started in background"}


@app.post("/admin/scrape/substances")
async def trigger_scrape_substances(request: Request):
    """Trigger substance scraping (admin endpoint).

    An optional JSON body may carry ``limit``, which is forwarded to the
    scraper. Returns HTTP 400 if a JSON body is declared but malformed.
    """
    try:
        body = await _optional_json_object(request)
    except ValueError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    limit = body.get("limit")

    from app.scraper.substances import scrape_all_substances

    _spawn_background(scrape_all_substances(limit=limit))
    return {"status": "started", "message": "Substance scraping started in background"}


@app.post("/admin/embed")
async def trigger_embedding():
    """Trigger embedding pipeline (admin endpoint)."""
    from app.embeddings import embed_all

    _spawn_background(embed_all())
    return {"status": "started", "message": "Embedding pipeline started in background"}