# erowid-bot/app/embeddings.py
# (205 lines, 7.4 KiB, Python)
import json
import logging
from typing import AsyncGenerator
import httpx
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import async_session
from app.models import Experience, Substance, DocumentChunk
logger = logging.getLogger(__name__)
# Persistent HTTP client for Ollama calls (avoids TCP handshake per request)
_ollama_client: httpx.AsyncClient | None = None
def _get_ollama_client() -> httpx.AsyncClient:
    """Return the shared Ollama HTTP client, (re)creating it when needed.

    A single module-level client is reused across calls so that the
    connection pool persists and each request avoids a fresh TCP handshake.
    """
    global _ollama_client
    needs_new_client = _ollama_client is None or _ollama_client.is_closed
    if needs_new_client:
        request_timeout = httpx.Timeout(connect=10, read=120, write=10, pool=10)
        _ollama_client = httpx.AsyncClient(
            base_url=settings.ollama_base_url,
            timeout=request_timeout,
        )
    return _ollama_client
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks by approximate token count.

    Token counts are approximated as ``words / 0.75`` (i.e. ~0.75 words
    per token), so the window sizes below are expressed in words.

    Args:
        text: The text to split.
        chunk_size: Target chunk size in approximate tokens.
        overlap: Approximate tokens shared between consecutive chunks.

    Returns:
        A list of chunk strings; a single-element list when the whole
        text fits in one chunk (including the empty string).
    """
    words = text.split()
    # Guard degenerate parameters: a chunk must hold at least one word,
    # and the overlap must be strictly smaller than the chunk size so the
    # window always advances (otherwise the loop below never terminates,
    # e.g. chunk_size == overlap made start = end - overlap = start).
    words_per_chunk = max(1, int(chunk_size * 0.75))
    words_overlap = max(0, min(int(overlap * 0.75), words_per_chunk - 1))
    if len(words) <= words_per_chunk:
        return [text]
    chunks = []
    start = 0
    while start < len(words):
        end = start + words_per_chunk
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # Final chunk already emitted; stepping back by the overlap
            # here would only produce redundant tail chunks made entirely
            # of words we just emitted.
            break
        start = end - words_overlap
    return chunks
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Fetch the embedding vector for *text* from the Ollama API.

    Args:
        text: The input text to embed.
        client: Optional HTTP client to reuse; falls back to the shared
            module-level client when omitted.

    Returns:
        The embedding as a list of floats.

    Raises:
        httpx.HTTPStatusError: If Ollama returns a non-2xx response.
    """
    http = client if client is not None else _get_ollama_client()
    payload = {
        "model": settings.ollama_embed_model,
        "prompt": text,
    }
    response = await http.post("/api/embeddings", json=payload)
    response.raise_for_status()
    return response.json()["embedding"]
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Embed each text in *texts* one request at a time, preserving order."""
    return [await get_embedding(item, client) for item in texts]
async def embed_experiences(batch_size: int = 20):
    """Chunk and embed all un-embedded experience reports.

    Args:
        batch_size: Number of experiences processed between DB commits,
            bounding the work lost if a run is interrupted.

    Returns:
        The total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Find experiences that don't have chunks yet.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "experience"
        ).distinct()
        result = await db.execute(
            select(Experience).where(Experience.id.not_in(subq))
        )
        experiences = result.scalars().all()
        logger.info(f"Found {len(experiences)} experiences to embed")
        # FIX: the previous ad-hoc `httpx.AsyncClient(timeout=60)` had no
        # base_url, so get_embedding()'s relative "/api/embeddings" path could
        # not resolve (httpx rejects relative URLs without a base_url). Reuse
        # the shared, correctly configured persistent client instead.
        client = _get_ollama_client()
        total_chunks = 0
        for i, exp in enumerate(experiences):
            # Prepend report metadata so every chunk carries searchable context.
            header = f"Experience Report: {exp.title}\n"
            header += f"Substance: {exp.substance}\n"
            if exp.category:
                header += f"Category: {exp.category}\n"
            if exp.gender:
                header += f"Gender: {exp.gender}\n"
            if exp.age:
                header += f"Age: {exp.age}\n"
            header += "\n"
            full_text = header + exp.body
            chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)
            for idx, chunk_content in enumerate(chunks):
                embedding = await get_embedding(chunk_content, client)
                metadata = json.dumps({
                    "title": exp.title,
                    "substance": exp.substance,
                    "category": exp.category,
                    "erowid_id": exp.erowid_id,
                })
                db.add(DocumentChunk(
                    source_type="experience",
                    source_id=exp.id,
                    chunk_index=idx,
                    content=chunk_content,
                    metadata_json=metadata,
                    embedding=embedding,
                ))
                total_chunks += 1
            # Commit periodically so progress survives interruptions.
            if (i + 1) % batch_size == 0:
                await db.commit()
                logger.info(f"Embedded {i + 1} experiences ({total_chunks} chunks)")
        await db.commit()
        logger.info(f"Done! Created {total_chunks} chunks from {len(experiences)} experiences")
        return total_chunks
async def embed_substances(batch_size: int = 10):
    """Chunk and embed all un-embedded substance info pages.

    Args:
        batch_size: Number of substances processed between DB commits,
            bounding the work lost if a run is interrupted.

    Returns:
        The total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Find substances that don't have chunks yet.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "substance"
        ).distinct()
        result = await db.execute(
            select(Substance).where(Substance.id.not_in(subq))
        )
        substances = result.scalars().all()
        logger.info(f"Found {len(substances)} substances to embed")
        # FIX: the previous ad-hoc `httpx.AsyncClient(timeout=60)` had no
        # base_url, so get_embedding()'s relative "/api/embeddings" path could
        # not resolve (httpx rejects relative URLs without a base_url). Reuse
        # the shared, correctly configured persistent client instead.
        client = _get_ollama_client()
        total_chunks = 0
        for i, sub in enumerate(substances):
            # Assemble a labelled, sectioned text so each chunk carries
            # searchable context; only non-empty sections are included.
            sections = [f"Substance Information: {sub.name}"]
            if sub.category:
                sections.append(f"Category: {sub.category}")
            if sub.description:
                sections.append(f"\nOverview:\n{sub.description}")
            if sub.effects:
                sections.append(f"\nEffects:\n{sub.effects}")
            if sub.dosage:
                sections.append(f"\nDosage:\n{sub.dosage}")
            if sub.duration:
                sections.append(f"\nDuration:\n{sub.duration}")
            if sub.chemistry:
                sections.append(f"\nChemistry:\n{sub.chemistry}")
            if sub.health:
                sections.append(f"\nHealth & Safety:\n{sub.health}")
            if sub.law:
                sections.append(f"\nLegal Status:\n{sub.law}")
            full_text = "\n".join(sections)
            chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)
            for idx, chunk_content in enumerate(chunks):
                embedding = await get_embedding(chunk_content, client)
                metadata = json.dumps({
                    "substance": sub.name,
                    "category": sub.category,
                })
                db.add(DocumentChunk(
                    source_type="substance",
                    source_id=sub.id,
                    chunk_index=idx,
                    content=chunk_content,
                    metadata_json=metadata,
                    embedding=embedding,
                ))
                total_chunks += 1
            # Commit periodically so progress survives interruptions.
            if (i + 1) % batch_size == 0:
                await db.commit()
                logger.info(f"Embedded {i + 1} substances ({total_chunks} chunks)")
        await db.commit()
        logger.info(f"Done! Created {total_chunks} chunks from {len(substances)} substances")
        return total_chunks
async def embed_all():
    """Embed everything that hasn't been embedded yet.

    Runs experience embedding first, then substances, and reports the
    chunk counts produced by each stage.
    """
    return {
        "experience_chunks": await embed_experiences(),
        "substance_chunks": await embed_substances(),
    }