120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
import json
|
|
import logging
|
|
from typing import AsyncGenerator
|
|
|
|
import httpx
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def stream_ollama(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
    """Stream a chat completion from Ollama.

    Args:
        messages: Chat history as ``[{"role": ..., "content": ...}]`` dicts.
        system: Optional system prompt, prepended as a ``"system"`` message.

    Yields:
        Incremental content tokens from the model.

    Raises:
        httpx.HTTPStatusError: If the Ollama server returns an error status.
    """
    all_messages = []
    if system:
        all_messages.append({"role": "system", "content": system})
    all_messages.extend(messages)

    payload = {
        "model": settings.ollama_chat_model,
        "messages": all_messages,
        "stream": True,
    }

    # Long read timeout: token generation can pause for a while between chunks.
    timeout = httpx.Timeout(connect=30, read=600, write=30, pool=30)
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{settings.ollama_base_url}/api/chat",
            json=payload,
        ) as resp:
            resp.raise_for_status()
            # Ollama streams NDJSON. aiter_lines() does the line buffering
            # for us and, unlike the previous manual byte-buffer split on
            # b"\n", also delivers a final line that arrives without a
            # trailing newline (that fragment was silently dropped before).
            async for line in resp.aiter_lines():
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Best-effort: skip malformed lines rather than abort the stream.
                    logger.debug("Skipping malformed NDJSON line from Ollama")
                    continue
                message = data.get("message")
                if isinstance(message, dict):
                    content = message.get("content")
                    if content:
                        yield content
                if data.get("done"):
                    return
|
|
|
|
|
|
async def stream_claude(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
    """Stream a chat completion from Claude API.

    Yields incremental text fragments as the model produces them.

    Raises:
        RuntimeError: If the ``anthropic`` package is not installed.
    """
    try:
        from anthropic import AsyncAnthropic
    except ImportError:
        raise RuntimeError("anthropic package not installed")

    api = AsyncAnthropic(api_key=settings.anthropic_api_key)

    stream_ctx = api.messages.stream(
        model="claude-sonnet-4-5-20250929",
        max_tokens=2048,
        system=system,
        messages=messages,
    )
    async with stream_ctx as stream:
        async for fragment in stream.text_stream:
            yield fragment
|
|
|
|
|
|
async def stream_openai(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
    """Stream a chat completion from OpenAI API.

    A system prompt, when given, is prepended to the message list.

    Raises:
        RuntimeError: If the ``openai`` package is not installed.
    """
    try:
        from openai import AsyncOpenAI
    except ImportError:
        raise RuntimeError("openai package not installed")

    api = AsyncOpenAI(api_key=settings.openai_api_key)

    # Final conversation: optional system message first, then the history.
    convo = ([{"role": "system", "content": system}] if system else []) + list(messages)

    response = await api.chat.completions.create(
        model="gpt-4o-mini",
        messages=convo,
        max_tokens=2048,
        stream=True,
    )

    async for part in response:
        delta = part.choices[0].delta.content if part.choices else None
        if delta:
            yield delta
|
|
|
|
|
|
async def stream_chat(messages: list[dict], system: str = "") -> AsyncGenerator[str, None]:
    """Route to the configured LLM provider.

    Reads ``settings.llm_provider`` and delegates to the matching
    streaming backend. Configuration problems (missing API key,
    unknown provider) are surfaced as a single yielded error string
    rather than an exception, so the caller's stream handling stays uniform.
    """
    provider = settings.llm_provider.lower()

    if provider == "ollama":
        source = stream_ollama(messages, system)
    elif provider == "claude":
        if not settings.anthropic_api_key:
            yield "Error: ANTHROPIC_API_KEY not configured. Set it in .env or switch LLM_PROVIDER to ollama."
            return
        source = stream_claude(messages, system)
    elif provider == "openai":
        if not settings.openai_api_key:
            yield "Error: OPENAI_API_KEY not configured. Set it in .env or switch LLM_PROVIDER to ollama."
            return
        source = stream_openai(messages, system)
    else:
        yield f"Error: Unknown LLM_PROVIDER '{provider}'. Use ollama, claude, or openai."
        return

    # Single consumption loop shared by all providers.
    async for token in source:
        yield token
|