# erowid-bot/app/embeddings.py
import json
import logging
from typing import AsyncGenerator
import httpx
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import async_session
from app.models import Experience, Substance, DocumentChunk
logger = logging.getLogger(__name__)
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split *text* into overlapping chunks by approximate token count.

    Tokens are approximated as ``words / 0.75`` (1 token ~ 0.75 words), so a
    ``chunk_size`` of 500 tokens maps to windows of ~375 words that slide
    forward by ``chunk_size - overlap`` tokens at a time.

    Args:
        text: Input text; split on whitespace.
        chunk_size: Target chunk size in approximate tokens.
        overlap: Approximate token overlap between consecutive chunks.

    Returns:
        List of chunk strings. Text that fits in a single window is returned
        as ``[text]`` unchanged (including empty input, which yields [""]).
    """
    words = text.split()
    # Approximate: 1 token ~ 0.75 words
    words_per_chunk = max(int(chunk_size * 0.75), 1)
    words_overlap = int(overlap * 0.75)
    if len(words) <= words_per_chunk:
        return [text]
    # Clamp the step to at least 1 word so overlap >= chunk_size can never
    # stall the loop (the original `start = end - overlap` looped forever).
    step = max(words_per_chunk - words_overlap, 1)
    chunks = []
    start = 0
    while start < len(words):
        end = start + words_per_chunk
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This window already reached the end of the text; continuing
            # would emit a trailing chunk that is a pure duplicate suffix
            # of this one.
            break
        start += step
    return chunks
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Return the embedding vector for *text* from the Ollama embeddings API.

    Args:
        text: Text to embed.
        client: Optional shared HTTP client. When omitted, a temporary
            client is created for this call and closed before returning.

    Returns:
        The embedding vector from the API response.

    Raises:
        httpx.HTTPStatusError: If the API responds with an error status.
    """
    owns_client = client is None
    if owns_client:
        client = httpx.AsyncClient(timeout=60)
    try:
        response = await client.post(
            f"{settings.ollama_base_url}/api/embeddings",
            json={"model": settings.ollama_embed_model, "prompt": text},
        )
        response.raise_for_status()
        return response.json()["embedding"]
    finally:
        # Only close clients we created; callers keep ownership of theirs.
        if owns_client:
            await client.aclose()
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Embed each text in *texts* one request at a time, preserving order."""
    return [await get_embedding(item, client) for item in texts]
async def embed_experiences(batch_size: int = 20):
    """Chunk and embed all experience reports that have no chunks yet.

    Experiences whose id already appears in DocumentChunk with
    ``source_type == "experience"`` are skipped, so the function is safe to
    re-run. Work is committed every ``batch_size`` experiences so progress
    survives a mid-run failure.

    Args:
        batch_size: Number of experiences per DB commit.

    Returns:
        Total number of chunks created in this run.
    """
    async with async_session() as db:
        # Experiences that already have at least one chunk are considered done.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "experience"
        ).distinct()
        result = await db.execute(
            select(Experience).where(Experience.id.not_in(subq))
        )
        experiences = result.scalars().all()
        # Lazy %-style args: no formatting cost when the log level is off.
        logger.info("Found %d experiences to embed", len(experiences))
        async with httpx.AsyncClient(timeout=60) as client:
            total_chunks = 0
            for i, exp in enumerate(experiences):
                # Prefix the body with structured context so embeddings
                # capture title/substance metadata, not just the report text.
                header_lines = [
                    f"Experience Report: {exp.title}",
                    f"Substance: {exp.substance}",
                ]
                if exp.category:
                    header_lines.append(f"Category: {exp.category}")
                if exp.gender:
                    header_lines.append(f"Gender: {exp.gender}")
                if exp.age:
                    header_lines.append(f"Age: {exp.age}")
                full_text = "\n".join(header_lines) + "\n\n" + exp.body
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)
                # Hoisted out of the chunk loop: identical for every chunk
                # of this experience.
                metadata = json.dumps({
                    "title": exp.title,
                    "substance": exp.substance,
                    "category": exp.category,
                    "erowid_id": exp.erowid_id,
                })
                for idx, content in enumerate(chunks):
                    embedding = await get_embedding(content, client)
                    db.add(DocumentChunk(
                        source_type="experience",
                        source_id=exp.id,
                        chunk_index=idx,
                        content=content,
                        metadata_json=metadata,
                        embedding=embedding,
                    ))
                    total_chunks += 1
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info("Embedded %d experiences (%d chunks)", i + 1, total_chunks)
            # Flush the final partial batch.
            await db.commit()
            logger.info(
                "Done! Created %d chunks from %d experiences",
                total_chunks, len(experiences),
            )
            return total_chunks
async def embed_substances(batch_size: int = 10):
    """Chunk and embed all substance info pages that have no chunks yet.

    Substances whose id already appears in DocumentChunk with
    ``source_type == "substance"`` are skipped, so the function is safe to
    re-run. Work is committed every ``batch_size`` substances so progress
    survives a mid-run failure.

    Args:
        batch_size: Number of substances per DB commit.

    Returns:
        Total number of chunks created in this run.
    """
    async with async_session() as db:
        # Substances that already have at least one chunk are considered done.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "substance"
        ).distinct()
        result = await db.execute(
            select(Substance).where(Substance.id.not_in(subq))
        )
        substances = result.scalars().all()
        # Lazy %-style args: no formatting cost when the log level is off.
        logger.info("Found %d substances to embed", len(substances))
        async with httpx.AsyncClient(timeout=60) as client:
            total_chunks = 0
            for i, sub in enumerate(substances):
                # Assemble the populated info sections into one document so
                # embeddings carry section context alongside the name.
                sections = [f"Substance Information: {sub.name}"]
                if sub.category:
                    sections.append(f"Category: {sub.category}")
                if sub.description:
                    sections.append(f"\nOverview:\n{sub.description}")
                if sub.effects:
                    sections.append(f"\nEffects:\n{sub.effects}")
                if sub.dosage:
                    sections.append(f"\nDosage:\n{sub.dosage}")
                if sub.duration:
                    sections.append(f"\nDuration:\n{sub.duration}")
                if sub.chemistry:
                    sections.append(f"\nChemistry:\n{sub.chemistry}")
                if sub.health:
                    sections.append(f"\nHealth & Safety:\n{sub.health}")
                if sub.law:
                    sections.append(f"\nLegal Status:\n{sub.law}")
                full_text = "\n".join(sections)
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)
                # Hoisted out of the chunk loop: identical for every chunk
                # of this substance.
                metadata = json.dumps({
                    "substance": sub.name,
                    "category": sub.category,
                })
                for idx, content in enumerate(chunks):
                    embedding = await get_embedding(content, client)
                    db.add(DocumentChunk(
                        source_type="substance",
                        source_id=sub.id,
                        chunk_index=idx,
                        content=content,
                        metadata_json=metadata,
                        embedding=embedding,
                    ))
                    total_chunks += 1
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info("Embedded %d substances (%d chunks)", i + 1, total_chunks)
            # Flush the final partial batch.
            await db.commit()
            logger.info(
                "Done! Created %d chunks from %d substances",
                total_chunks, len(substances),
            )
            return total_chunks
async def embed_all():
    """Run both embedding passes and report how many chunks each created."""
    return {
        "experience_chunks": await embed_experiences(),
        "substance_chunks": await embed_substances(),
    }