diff --git a/app/config.py b/app/config.py index 92ac0fc..be5efea 100644 --- a/app/config.py +++ b/app/config.py @@ -9,7 +9,7 @@ class Settings(BaseSettings): ollama_base_url: str = "http://ollama:11434" ollama_embed_model: str = "nomic-embed-text" - ollama_chat_model: str = "llama3.1:8b" + ollama_chat_model: str = "llama3.2:1b" anthropic_api_key: str = "" openai_api_key: str = "" @@ -24,7 +24,7 @@ class Settings(BaseSettings): # RAG settings chunk_size: int = 500 # tokens per chunk chunk_overlap: int = 50 # token overlap between chunks - retrieval_top_k: int = 4 # number of chunks to retrieve + retrieval_top_k: int = 3 # number of chunks to retrieve class Config: env_file = ".env" diff --git a/app/llm.py b/app/llm.py index b033a50..11e467f 100644 --- a/app/llm.py +++ b/app/llm.py @@ -20,6 +20,10 @@ async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerato "model": settings.ollama_chat_model, "messages": all_messages, "stream": True, + "options": { + "num_ctx": 2048, + "num_predict": 512, + }, } timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30) diff --git a/app/rag.py b/app/rag.py index a855e12..25cfd4d 100644 --- a/app/rag.py +++ b/app/rag.py @@ -13,7 +13,7 @@ from app.models import DocumentChunk logger = logging.getLogger(__name__) -SYSTEM_PROMPT = """You are the Erowid Knowledge Assistant focused on harm reduction. Provide accurate, non-judgmental substance info from the Erowid database. Prioritize safety. Never encourage drug use. Cite sources when possible. Say when info is limited.""" +SYSTEM_PROMPT = """Erowid harm-reduction assistant. Give accurate, non-judgmental substance info from the provided context. Prioritize safety. Never encourage use. 
+Be concise."""