# RAG chat service: vector retrieval, context prompt construction, and
# streamed LLM responses.
import json
|
|
import logging
|
|
from typing import AsyncGenerator
|
|
|
|
from sqlalchemy import select, text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import settings
|
|
from app.database import async_session
|
|
from app.embeddings import get_embedding
|
|
from app.llm import stream_chat
|
|
from app.models import DocumentChunk
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Base system prompt for every LLM call; chat_stream() appends the retrieved
# document context to this at request time. NOTE: this text is sent to the
# model verbatim — do not reformat it casually.
SYSTEM_PROMPT = """Erowid harm-reduction assistant. Give accurate, non-judgmental substance info from the provided context. Prioritize safety. Never encourage use. Be concise."""
|
|
|
|
|
|
async def retrieve_context(query: str, top_k: int | None = None) -> list[dict]:
    """Retrieve the most relevant document chunks for a query.

    Embeds *query* and runs a pgvector cosine-distance search over the
    ``document_chunks`` table.

    Args:
        query: Free-text query to embed and search with.
        top_k: Number of chunks to return; defaults to
            ``settings.retrieval_top_k`` when ``None``.

    Returns:
        A list of dicts with keys ``id``, ``source_type``, ``source_id``,
        ``chunk_index``, ``content``, ``metadata`` (parsed JSON; ``{}`` when
        missing or unparseable), and ``distance`` (cosine distance — lower
        means more similar).
    """
    if top_k is None:
        top_k = settings.retrieval_top_k

    # Get query embedding
    query_embedding = await get_embedding(query)

    async with async_session() as db:
        # pgvector's <=> operator computes cosine distance. The embedding and
        # limit are bound parameters, so the SQL text itself stays static.
        result = await db.execute(
            text("""
                SELECT id, source_type, source_id, chunk_index, content, metadata_json,
                       embedding <=> :query_embedding AS distance
                FROM document_chunks
                ORDER BY embedding <=> :query_embedding
                LIMIT :top_k
            """),
            {"query_embedding": str(query_embedding), "top_k": top_k},
        )

        chunks = []
        for row in result.fetchall():
            metadata = {}
            if row[5]:
                try:
                    metadata = json.loads(row[5])
                except json.JSONDecodeError:
                    # Best-effort: a corrupt metadata blob should not abort
                    # retrieval, but silently dropping it hides data issues —
                    # log so bad rows can be found and repaired.
                    logger.warning(
                        "Invalid metadata_json for document_chunk id=%s; ignoring",
                        row[0],
                    )

            chunks.append({
                "id": row[0],
                "source_type": row[1],
                "source_id": row[2],
                "chunk_index": row[3],
                "content": row[4],
                "metadata": metadata,
                "distance": row[6],
            })

        return chunks
|
|
|
|
|
|
def build_context_prompt(chunks: list[dict], max_chunk_chars: int = 500) -> str:
    """Build a context string from retrieved chunks.

    Each chunk is rendered as a ``--- Source N (...) ---`` header followed by
    its (truncated) content; chunks are separated by blank lines.

    Args:
        chunks: Retrieved chunk dicts as produced by ``retrieve_context``;
            must contain ``source_type``, ``metadata``, and ``content``.
        max_chunk_chars: Per-chunk content cap to avoid overwhelming the LLM.
            Defaults to 500, matching the previous hard-coded limit.

    Returns:
        The formatted context string, or a bracketed placeholder when no
        chunks were retrieved.
    """
    if not chunks:
        return "\n[No relevant documents found in the database.]\n"

    context_parts = []
    for i, chunk in enumerate(chunks, 1):
        source_label = chunk["source_type"].title()
        metadata = chunk["metadata"]

        # Header carries provenance so the LLM can attribute its answer.
        header = f"--- Source {i} ({source_label})"
        if "title" in metadata:
            header += f" | {metadata['title']}"
        if "substance" in metadata:
            header += f" | Substance: {metadata['substance']}"
        header += " ---"

        # Limit each chunk to avoid overwhelming the LLM
        content = chunk["content"][:max_chunk_chars]
        context_parts.append(f"{header}\n{content}")

    return "\n\n".join(context_parts)
|
|
|
|
|
|
async def chat_stream(
    user_message: str,
    conversation_history: list[dict] | None = None,
) -> AsyncGenerator[str, None]:
    """Full RAG pipeline: retrieve context, build prompt, stream LLM response.

    Args:
        user_message: Latest user turn; used both as the retrieval query and
            as the final message sent to the LLM.
        conversation_history: Optional prior role/content message dicts; only
            the six most recent are forwarded to bound prompt size.

    Yields:
        Response tokens as streamed by the underlying LLM client.
    """
    # Fetch the most relevant document chunks for this turn.
    retrieved = await retrieve_context(user_message)

    # Fold the retrieved context into the system prompt.
    context_block = build_context_prompt(retrieved)
    system_prompt = (
        f"{SYSTEM_PROMPT}\n\n"
        f"--- RELEVANT EROWID DATA ---\n{context_block}\n--- END EROWID DATA ---"
    )

    # Assemble the message list: at most six trailing history turns, then the
    # current user message. Slicing copies, so the caller's list is untouched.
    history = conversation_history[-6:] if conversation_history else []
    history.append({"role": "user", "content": user_message})

    # Relay tokens from the LLM as they arrive.
    async for piece in stream_chat(history, system=system_prompt):
        yield piece
|