"""RAG chat pipeline: pgvector retrieval + context-augmented LLM streaming."""

import json
import logging
from typing import AsyncGenerator

from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import async_session
from app.embeddings import get_embedding
from app.llm import stream_chat
from app.models import DocumentChunk

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are the Erowid Knowledge Assistant focused on harm reduction.
Provide accurate, non-judgmental substance info from the Erowid database.
Prioritize safety. Never encourage drug use. Cite sources when possible.
Say when info is limited."""


async def retrieve_context(query: str, top_k: int | None = None) -> list[dict]:
    """Retrieve the most relevant document chunks for a query.

    Args:
        query: Natural-language user query; embedded and matched against
            stored chunk embeddings.
        top_k: Number of chunks to return. Defaults to
            ``settings.retrieval_top_k`` when None.

    Returns:
        Chunk dicts ordered by ascending cosine distance (most similar
        first), each with keys ``id``, ``source_type``, ``source_id``,
        ``chunk_index``, ``content``, ``metadata`` (dict, possibly empty)
        and ``distance``.
    """
    if top_k is None:
        top_k = settings.retrieval_top_k

    # Embed the query with the same model used at indexing time.
    query_embedding = await get_embedding(query)

    async with async_session() as db:
        # pgvector's `<=>` operator is cosine distance: smaller = more
        # similar. The embedding is bound as its text form (e.g.
        # "[0.1, 0.2, ...]"), which pgvector casts to a vector.
        result = await db.execute(
            text("""
                SELECT id, source_type, source_id, chunk_index, content,
                       metadata_json,
                       embedding <=> :query_embedding AS distance
                FROM document_chunks
                ORDER BY embedding <=> :query_embedding
                LIMIT :top_k
            """),
            {"query_embedding": str(query_embedding), "top_k": top_k},
        )

        chunks = []
        for row in result.fetchall():
            # Positional columns match the SELECT list above.
            raw_meta = row[5]
            metadata = {}
            if raw_meta:
                if isinstance(raw_meta, dict):
                    # Some drivers decode JSON/JSONB columns to dicts
                    # already; json.loads would raise TypeError on those.
                    metadata = raw_meta
                else:
                    try:
                        metadata = json.loads(raw_meta)
                    except (json.JSONDecodeError, TypeError):
                        # Best-effort: malformed metadata must not break
                        # retrieval, but leave a trace for debugging.
                        logger.debug(
                            "Unparseable metadata_json for chunk %s", row[0]
                        )
            chunks.append({
                "id": row[0],
                "source_type": row[1],
                "source_id": row[2],
                "chunk_index": row[3],
                "content": row[4],
                "metadata": metadata,
                "distance": row[6],
            })
        return chunks


def build_context_prompt(chunks: list[dict], max_chunk_chars: int = 800) -> str:
    """Build a context string from retrieved chunks.

    Args:
        chunks: Chunk dicts as returned by :func:`retrieve_context`.
        max_chunk_chars: Per-chunk truncation limit so a single long
            document cannot crowd out the others (default 800).

    Returns:
        A prompt-ready string with one labelled section per chunk, or a
        bracketed "no documents" marker when *chunks* is empty.
    """
    if not chunks:
        return "\n[No relevant documents found in the database.]\n"

    context_parts = []
    for i, chunk in enumerate(chunks, 1):
        source_label = chunk["source_type"].title()
        metadata = chunk["metadata"]
        header = f"--- Source {i} ({source_label})"
        if "title" in metadata:
            header += f" | {metadata['title']}"
        if "substance" in metadata:
            header += f" | Substance: {metadata['substance']}"
        header += " ---"
        # Limit each chunk to avoid overwhelming the LLM.
        content = chunk["content"][:max_chunk_chars]
        context_parts.append(f"{header}\n{content}")
    return "\n\n".join(context_parts)


async def chat_stream(
    user_message: str,
    conversation_history: list[dict] | None = None,
) -> AsyncGenerator[str, None]:
    """Full RAG pipeline: retrieve context, build prompt, stream LLM response.

    Args:
        user_message: The user's latest message; also used as the
            retrieval query.
        conversation_history: Optional prior messages as
            ``{"role": ..., "content": ...}`` dicts. Only the most recent
            10 are forwarded to bound prompt size.

    Yields:
        Response tokens as produced by the LLM.
    """
    # Retrieve relevant chunks for the latest user message.
    chunks = await retrieve_context(user_message)

    # Splice the retrieved context into the system prompt.
    context_text = build_context_prompt(chunks)
    full_system = (
        f"{SYSTEM_PROMPT}\n\n--- RELEVANT EROWID DATA ---\n"
        f"{context_text}\n--- END EROWID DATA ---"
    )

    # Build message history; slicing copies, so the caller's list is
    # never mutated by the append below.
    messages = []
    if conversation_history:
        # Keep only the last 10 messages for context.
        messages = conversation_history[-10:]
    messages.append({"role": "user", "content": user_message})

    # Stream tokens straight through from the LLM.
    async for token in stream_chat(messages, system=full_system):
        yield token