"""FastAPI application for the Erowid Bot.

Exposes a streaming chat endpoint (server-sent events), a database stats
endpoint, and admin endpoints that start scraping / embedding jobs in the
background.
"""

import asyncio
import json
import logging
import uuid
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from app.config import settings
from app.database import init_db, async_session
from app.rag import chat_stream

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Maximum number of messages (user + assistant) retained per session.
MAX_HISTORY_MESSAGES = 20
# Maximum number of sessions kept in memory; the least recently updated
# session is evicted first so the store cannot grow without bound.
MAX_SESSIONS = 1000

# In-memory session store (conversation history per session)
sessions: dict[str, list[dict]] = {}

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


def _on_background_task_done(task):
    """Release a finished background task and log how it ended."""
    _background_tasks.discard(task)
    if task.cancelled():
        logger.warning("Background task %s was cancelled", task.get_name())
        return
    exc = task.exception()
    if exc is not None:
        logger.error(
            "Background task %s failed", task.get_name(), exc_info=exc
        )


def _spawn_background(coro, name):
    """Schedule *coro* as a background task that stays referenced until done.

    Args:
        coro: The coroutine object to run.
        name: Task name used in log messages.

    Returns:
        The created ``asyncio.Task``.
    """
    task = asyncio.create_task(coro, name=name)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


def _bad_request(message):
    """Build a 400 response using the same ``{"error": ...}`` shape as /chat."""
    return JSONResponse({"error": message}, status_code=400)


async def _read_json_object(request: Request) -> dict:
    """Parse the request body as a JSON object.

    An empty body is treated as an empty object.

    Raises:
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    raw = await request.body()
    if not raw.strip():
        return {}
    # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    payload = json.loads(raw)
    if not isinstance(payload, dict):
        raise ValueError("JSON body must be an object")
    return payload


async def _read_admin_options(request: Request) -> dict:
    """Return the optional JSON options sent to an admin endpoint.

    The body is only parsed when the request declares a JSON media type;
    parameters such as ``; charset=utf-8`` are ignored for that check.

    Raises:
        ValueError: If a JSON body is declared but malformed.
    """
    content_type = request.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        return {}
    return await _read_json_object(request)


def _store_session(session_id, history):
    """Trim *history* in place and save it as the most recent session."""
    if len(history) > MAX_HISTORY_MESSAGES:
        history[:] = history[-MAX_HISTORY_MESSAGES:]
    # Re-insert so dict order reflects recency, then evict the oldest entries.
    sessions.pop(session_id, None)
    sessions[session_id] = history
    while len(sessions) > MAX_SESSIONS:
        del sessions[next(iter(sessions))]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the database before the application starts serving."""
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database ready.")
    yield


app = FastAPI(title="Erowid Bot", lifespan=lifespan)

# Serve static files
static_dir = Path(__file__).parent / "static"
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page front end from the static directory."""
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    return (static_dir / "index.html").read_text(encoding="utf-8")


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint with streaming SSE response.

    Expects a JSON object with a ``message`` string and an optional
    ``session_id`` string. Streams ``{"token": ...}`` events, an
    ``{"error": ...}`` event if generation fails, and always finishes with a
    ``{"done": true, "session_id": ...}`` event.

    Returns 400 for a malformed body, a non-string field, or an empty message.
    """
    try:
        body = await _read_json_object(request)
    except ValueError:
        return _bad_request("Request body must be a JSON object")

    message = body.get("message")
    if message is None:
        message = ""
    if not isinstance(message, str):
        return _bad_request("message must be a string")
    message = message.strip()
    if not message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    session_id = body.get("session_id")
    if session_id is not None and not isinstance(session_id, str):
        # Non-string ids (e.g. lists) are not usable as session keys.
        return _bad_request("session_id must be a string")
    if not session_id:
        session_id = str(uuid.uuid4())

    # Get or create conversation history
    history = sessions.get(session_id, [])

    async def generate():
        tokens = []
        failed = False
        try:
            async for token in chat_stream(message, history):
                tokens.append(token)
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception:
            failed = True
            # Full traceback goes to the log; the client only gets a generic
            # message so internal details are not leaked.
            logger.exception("Chat error (session %s)", session_id)
            error_event = {"error": "Internal error while generating the response"}
            yield f"data: {json.dumps(error_event)}\n\n"

        if not failed:
            # Only completed exchanges are recorded, so a failed stream never
            # leaves an empty or truncated assistant turn in the history.
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "".join(tokens)})
            _store_session(session_id, history)

        yield f"data: {json.dumps({'done': True, 'session_id': session_id})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@app.get("/stats")
async def stats():
    """Get database stats.

    Returns the row counts of the Experience, Substance and DocumentChunk
    tables under the keys ``experiences``, ``substances`` and ``chunks``.
    """
    # Function-local imports kept from the original layout
    # (presumably to defer loading the models — confirm before moving).
    from sqlalchemy import func, select
    from app.models import Experience, Substance, DocumentChunk

    async with async_session() as db:
        exp_count = (await db.execute(select(func.count(Experience.id)))).scalar() or 0
        sub_count = (await db.execute(select(func.count(Substance.id)))).scalar() or 0
        chunk_count = (await db.execute(select(func.count(DocumentChunk.id)))).scalar() or 0

    return {
        "experiences": exp_count,
        "substances": sub_count,
        "chunks": chunk_count,
    }


# NOTE(review): the /admin/* endpoints below perform no authentication in
# this module — confirm access is restricted elsewhere (proxy, middleware).


@app.post("/admin/scrape/experiences")
async def trigger_scrape_experiences(request: Request):
    """Trigger experience scraping (admin endpoint).

    An optional JSON body may carry ``limit``, which is forwarded unchanged
    to the scraper. Returns 400 if a JSON body is declared but malformed.
    """
    try:
        body = await _read_admin_options(request)
    except ValueError:
        return _bad_request("Request body must be a JSON object")
    limit = body.get("limit")

    from app.scraper.experiences import scrape_all_experiences

    _spawn_background(scrape_all_experiences(limit=limit), name="scrape-experiences")
    return {"status": "started", "message": "Experience scraping started in background"}


@app.post("/admin/scrape/substances")
async def trigger_scrape_substances(request: Request):
    """Trigger substance scraping (admin endpoint).

    An optional JSON body may carry ``limit``, which is forwarded unchanged
    to the scraper. Returns 400 if a JSON body is declared but malformed.
    """
    try:
        body = await _read_admin_options(request)
    except ValueError:
        return _bad_request("Request body must be a JSON object")
    limit = body.get("limit")

    from app.scraper.substances import scrape_all_substances

    _spawn_background(scrape_all_substances(limit=limit), name="scrape-substances")
    return {"status": "started", "message": "Substance scraping started in background"}


@app.post("/admin/embed")
async def trigger_embedding():
    """Trigger embedding pipeline (admin endpoint)."""
    from app.embeddings import embed_all

    _spawn_background(embed_all(), name="embed-all")
    return {"status": "started", "message": "Embedding pipeline started in background"}