141 lines
4.4 KiB
Python
141 lines
4.4 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from app.config import settings
|
|
from app.database import init_db, async_session
|
|
from app.rag import chat_stream
|
|
|
|
# Configure process-wide logging at import time: INFO level, with the
# timestamp, logger name and severity in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Conversation history, kept only in process memory and keyed by session id.
# Each value is the list of {"role": ..., "content": ...} messages recorded by
# the /chat endpoint for that session.
sessions: dict[str, list[dict]] = {}
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: initialize the database before serving.

    Everything before the ``yield`` runs once at startup. There is no code
    after the ``yield``, so nothing is torn down here at shutdown.
    """
    logger.info("Initializing database...")
    await init_db()
    logger.info("Database ready.")

    # Hand control back to the framework for the lifetime of the app.
    yield
|
|
|
|
|
|
app = FastAPI(
    title="Erowid Bot",
    lifespan=lifespan,
)

# Static assets live in the "static" directory next to this module and are
# exposed under the /static URL prefix.
static_dir = Path(__file__).parent.joinpath("static")
app.mount(
    "/static",
    StaticFiles(directory=str(static_dir)),
    name="static",
)
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page (``static/index.html``) as the site root.

    Returns:
        The text of ``index.html``; the route's ``HTMLResponse`` response
        class sends it as HTML.
    """
    # Decode explicitly as UTF-8. Without an encoding argument ``read_text``
    # falls back to the platform's locale encoding, which can corrupt
    # non-ASCII characters in the page on hosts whose locale is not UTF-8.
    return (static_dir / "index.html").read_text(encoding="utf-8")
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports a static "ok" status."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
# Maximum number of messages (user + assistant entries) kept per session.
_MAX_HISTORY_MESSAGES = 20


@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint with streaming SSE response.

    Expects a JSON object body with:
        message: the user's text (required, non-blank string).
        session_id: optional string identifying an existing conversation;
            when missing, empty, or not a string, a new UUID is generated.

    Returns:
        A 400 ``JSONResponse`` when the body is not a JSON object or the
        message is missing/blank. Otherwise a ``text/event-stream`` response
        that emits one ``{"token": ...}`` frame per generated token, an
        ``{"error": ...}`` frame if generation fails, and always a final
        ``{"done": true, "session_id": ...}`` frame.
    """
    # json.JSONDecodeError and UnicodeDecodeError are both ValueError
    # subclasses, so this covers malformed JSON and undecodable bytes.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    raw_message = body.get("message", "")
    message = raw_message.strip() if isinstance(raw_message, str) else ""
    if not message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    session_id = body.get("session_id", "")
    if not isinstance(session_id, str) or not session_id:
        session_id = str(uuid.uuid4())

    # Get or create conversation history.
    # NOTE(review): `sessions` is never pruned, so memory grows with the
    # number of distinct session ids; consider an eviction policy.
    history = sessions.get(session_id, [])

    async def generate():
        """Yield SSE frames for the reply and record the turn on success."""
        tokens: list[str] = []
        failed = False
        try:
            async for token in chat_stream(message, history):
                tokens.append(token)
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception:
            failed = True
            # logger.exception keeps the traceback; the client only gets a
            # generic message so internal details are not leaked.
            logger.exception("Chat error")
            yield f"data: {json.dumps({'error': 'Internal error while generating the response'})}\n\n"

        if not failed:
            # Only completed turns are stored, so a failed request does not
            # leave an empty or truncated assistant message in the history.
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "".join(tokens)})
            # Keep history bounded to the most recent messages.
            if len(history) > _MAX_HISTORY_MESSAGES:
                history[:] = history[-_MAX_HISTORY_MESSAGES:]
            sessions[session_id] = history

        yield f"data: {json.dumps({'done': True, 'session_id': session_id})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
@app.get("/stats")
async def stats():
    """Return row counts for experiences, substances and document chunks."""
    from sqlalchemy import func, select
    from app.models import Experience, Substance, DocumentChunk

    async def count_rows(db, id_column):
        # COUNT over the id column; a missing scalar result is reported as 0.
        result = await db.execute(select(func.count(id_column)))
        return result.scalar() or 0

    # The three counts run one after another on the same session.
    async with async_session() as db:
        totals = {
            "experiences": await count_rows(db, Experience.id),
            "substances": await count_rows(db, Substance.id),
            "chunks": await count_rows(db, DocumentChunk.id),
        }

    return totals
|
|
|
|
|
|
# Strong references to in-flight background jobs. The event loop only keeps
# weak references to tasks, so a task whose result is discarded can be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_done(task: asyncio.Task) -> None:
    """Drop the reference to a finished job and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error("Background task %r failed", task.get_name(), exc_info=error)


def _spawn_background(coro, name: str) -> asyncio.Task:
    """Schedule ``coro`` as a tracked background task and return it."""
    task = asyncio.create_task(coro, name=name)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_done)
    return task


async def _read_limit(request: Request):
    """Extract the optional ``limit`` field from a JSON request body.

    Returns:
        A ``(limit, error_response)`` pair. ``error_response`` is a 400
        ``JSONResponse`` when the body is unusable, otherwise ``None``.
        ``limit`` is ``None`` when the request carries no JSON body or no
        ``limit`` key.
    """
    # Compare only the media type so that parameters such as
    # "application/json; charset=utf-8" are still recognised as JSON.
    content_type = request.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        return None, None

    raw = await request.body()
    if not raw.strip():
        # A JSON content type with an empty body is treated as "no options".
        return None, None

    try:
        body = json.loads(raw)
    except ValueError:
        return None, JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return None, JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    limit = body.get("limit")
    # bool is a subclass of int, so it is rejected explicitly.
    # NOTE(review): assumes the scrapers expect an integer limit - confirm.
    if limit is not None and (isinstance(limit, bool) or not isinstance(limit, int)):
        return None, JSONResponse({"error": "limit must be an integer"}, status_code=400)
    return limit, None


# NOTE(review): none of the /admin endpoints below perform authentication or
# authorization in this module - confirm they are protected elsewhere.


@app.post("/admin/scrape/experiences")
async def trigger_scrape_experiences(request: Request):
    """Trigger experience scraping (admin endpoint).

    Accepts an optional JSON object body with an integer ``limit`` that is
    forwarded to the scraper. The scrape runs as a background task; the
    response is returned immediately. Responds with 400 on an invalid body.
    """
    limit, error = await _read_limit(request)
    if error is not None:
        return error

    from app.scraper.experiences import scrape_all_experiences

    _spawn_background(scrape_all_experiences(limit=limit), name="scrape-experiences")
    return {"status": "started", "message": "Experience scraping started in background"}


@app.post("/admin/scrape/substances")
async def trigger_scrape_substances(request: Request):
    """Trigger substance scraping (admin endpoint).

    Accepts an optional JSON object body with an integer ``limit`` that is
    forwarded to the scraper. The scrape runs as a background task; the
    response is returned immediately. Responds with 400 on an invalid body.
    """
    limit, error = await _read_limit(request)
    if error is not None:
        return error

    from app.scraper.substances import scrape_all_substances

    _spawn_background(scrape_all_substances(limit=limit), name="scrape-substances")
    return {"status": "started", "message": "Substance scraping started in background"}


@app.post("/admin/embed")
async def trigger_embedding():
    """Trigger embedding pipeline (admin endpoint).

    The pipeline runs as a background task; the response is returned
    immediately.
    """
    from app.embeddings import embed_all

    _spawn_background(embed_all(), name="embed-all")
    return {"status": "started", "message": "Embedding pipeline started in background"}
|