# erowid-bot/app/rag.py
import json
import logging
from typing import AsyncGenerator
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import async_session
from app.embeddings import get_embedding
from app.llm import stream_chat
from app.models import DocumentChunk
# Module-level logger, following the standard getLogger(__name__) convention
# (not referenced elsewhere in this visible chunk; presumably used by callers
# or kept for future diagnostics).
logger = logging.getLogger(__name__)
# Base system prompt for every LLM call; chat_stream appends the retrieved
# document context to it at request time.
SYSTEM_PROMPT = """Erowid harm-reduction assistant. Give accurate, non-judgmental substance info from the provided context. Prioritize safety. Never encourage use. Be concise."""
async def retrieve_context(query: str, top_k: int | None = None) -> list[dict]:
    """Return the document chunks most relevant to *query*.

    Embeds the query, then runs a pgvector cosine-distance (`<=>`) search
    over the ``document_chunks`` table, closest matches first.

    Args:
        query: Free-text user query to embed and search with.
        top_k: Number of chunks to return; defaults to
            ``settings.retrieval_top_k``.

    Returns:
        List of chunk dicts (id, source_type, source_id, chunk_index,
        content, metadata, distance) ordered by ascending distance.
    """
    limit = settings.retrieval_top_k if top_k is None else top_k
    # Embed once; the same vector drives both the SELECT column and ORDER BY.
    query_embedding = await get_embedding(query)
    stmt = text("""
        SELECT id, source_type, source_id, chunk_index, content, metadata_json,
               embedding <=> :query_embedding AS distance
        FROM document_chunks
        ORDER BY embedding <=> :query_embedding
        LIMIT :top_k
        """)
    async with async_session() as db:
        # str() on the embedding list yields the "[x, y, ...]" literal form
        # that pgvector accepts as a vector parameter.
        result = await db.execute(
            stmt,
            {"query_embedding": str(query_embedding), "top_k": limit},
        )
        rows = result.fetchall()

    def _parse_metadata(raw) -> dict:
        # metadata_json is stored as a JSON string; tolerate NULL/empty and
        # malformed rows rather than failing the whole retrieval.
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}

    return [
        {
            "id": row[0],
            "source_type": row[1],
            "source_id": row[2],
            "chunk_index": row[3],
            "content": row[4],
            "metadata": _parse_metadata(row[5]),
            "distance": row[6],
        }
        for row in rows
    ]
def build_context_prompt(chunks: list[dict], max_chars: int = 300) -> str:
    """Build a context string from retrieved chunks.

    Args:
        chunks: Chunk dicts as produced by ``retrieve_context``; each must
            carry ``source_type``, ``metadata``, and ``content`` keys.
        max_chars: Maximum number of content characters included per chunk.
            Defaults to 300, matching the previous hard-coded truncation.

    Returns:
        One "--- Source N (...) ---" header plus truncated content per chunk,
        joined by blank lines, or a bracketed placeholder when *chunks* is
        empty.
    """
    if not chunks:
        return "\n[No relevant documents found in the database.]\n"
    context_parts = []
    for i, chunk in enumerate(chunks, 1):
        source_label = chunk["source_type"].title()
        metadata = chunk["metadata"]
        # The header carries provenance so the LLM can ground/cite answers.
        header = f"--- Source {i} ({source_label})"
        if "title" in metadata:
            header += f" | {metadata['title']}"
        if "substance" in metadata:
            header += f" | Substance: {metadata['substance']}"
        header += " ---"
        # Truncate content to keep the prompt within a predictable budget.
        content = chunk["content"][:max_chars]
        context_parts.append(f"{header}\n{content}")
    return "\n\n".join(context_parts)
async def chat_stream(
    user_message: str,
    conversation_history: list[dict] | None = None,
) -> AsyncGenerator[str, None]:
    """Full RAG pipeline: retrieve context, build prompt, stream LLM response.

    Args:
        user_message: The latest user turn to answer.
        conversation_history: Optional prior turns as role/content dicts;
            only the last four are forwarded to keep the request small.

    Yields:
        Response tokens from the LLM as they arrive.
    """
    retrieved = await retrieve_context(user_message)
    context_block = build_context_prompt(retrieved)
    # Fold the retrieved documents into the system prompt so the model
    # grounds its answer in database content.
    system_prompt = f"{SYSTEM_PROMPT}\n\nContext:\n{context_block}"
    # Keep only the tail of the history (slicing also avoids mutating the
    # caller's list when we append below).
    messages = conversation_history[-4:] if conversation_history else []
    messages.append({"role": "user", "content": user_message})
    async for token in stream_chat(messages, system=system_prompt):
        yield token