"""Chunk and embed experience reports and substance pages using Ollama embeddings.

Each source row is rendered to a rich text representation, split into
overlapping chunks, embedded via the Ollama /api/embeddings endpoint, and
stored as DocumentChunk rows. Rows that already have chunks are skipped.
"""

import json
import logging
from typing import AsyncGenerator

import httpx
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import async_session
from app.models import Experience, Substance, DocumentChunk

logger = logging.getLogger(__name__)


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks by approximate token count.

    Args:
        text: The text to split.
        chunk_size: Target chunk length in tokens (1 token ~ 0.75 words).
        overlap: Number of tokens shared between consecutive chunks.

    Returns:
        A list of chunk strings; a single-element list when the text fits
        in one chunk (including the empty string for empty input).

    Raises:
        ValueError: If ``overlap >= chunk_size`` — without this guard the
            window would never advance and the loop would run forever.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    # Approximate: 1 token ~ 0.75 words
    words_per_chunk = int(chunk_size * 0.75)
    words_overlap = int(overlap * 0.75)

    if len(words) <= words_per_chunk:
        return [text]

    chunks: list[str] = []
    start = 0
    # Slide a word window forward by (chunk - overlap) each step; the guard
    # above guarantees the step is positive, so this always terminates.
    while start < len(words):
        chunks.append(" ".join(words[start:start + words_per_chunk]))
        start += words_per_chunk - words_overlap
    return chunks


async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Get the embedding vector for ``text`` from the Ollama embeddings API.

    Args:
        text: Text to embed.
        client: Optional shared HTTP client; when omitted, a temporary
            client is created and closed after the request.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        httpx.HTTPStatusError: If the Ollama endpoint returns an error status.
    """
    should_close = False
    if client is None:
        client = httpx.AsyncClient(timeout=60)
        should_close = True
    try:
        resp = await client.post(
            f"{settings.ollama_base_url}/api/embeddings",
            json={
                "model": settings.ollama_embed_model,
                "prompt": text,
            },
        )
        resp.raise_for_status()
        data = resp.json()
        return data["embedding"]
    finally:
        # Only close clients we created; callers keep ownership of theirs.
        if should_close:
            await client.aclose()


async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Get embeddings for multiple texts sequentially over a shared client."""
    embeddings = []
    for text in texts:
        emb = await get_embedding(text, client)
        embeddings.append(emb)
    return embeddings


def _experience_text(exp: Experience) -> str:
    """Render one experience report as a headered text document for embedding."""
    header = f"Experience Report: {exp.title}\n"
    header += f"Substance: {exp.substance}\n"
    if exp.category:
        header += f"Category: {exp.category}\n"
    if exp.gender:
        header += f"Gender: {exp.gender}\n"
    if exp.age:
        header += f"Age: {exp.age}\n"
    header += "\n"
    return header + exp.body


def _substance_text(sub: Substance) -> str:
    """Render one substance info page as a sectioned text document for embedding."""
    sections = [f"Substance Information: {sub.name}"]
    if sub.category:
        sections.append(f"Category: {sub.category}")
    if sub.description:
        sections.append(f"\nOverview:\n{sub.description}")
    if sub.effects:
        sections.append(f"\nEffects:\n{sub.effects}")
    if sub.dosage:
        sections.append(f"\nDosage:\n{sub.dosage}")
    if sub.duration:
        sections.append(f"\nDuration:\n{sub.duration}")
    if sub.chemistry:
        sections.append(f"\nChemistry:\n{sub.chemistry}")
    if sub.health:
        sections.append(f"\nHealth & Safety:\n{sub.health}")
    if sub.law:
        sections.append(f"\nLegal Status:\n{sub.law}")
    return "\n".join(sections)


async def embed_experiences(batch_size: int = 20) -> int:
    """Chunk and embed all un-embedded experience reports.

    Args:
        batch_size: Number of experiences processed between DB commits.

    Returns:
        Total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Find experiences that don't have chunks yet.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "experience"
        ).distinct()
        result = await db.execute(
            select(Experience).where(Experience.id.not_in(subq))
        )
        experiences = result.scalars().all()
        logger.info("Found %s experiences to embed", len(experiences))

        async with httpx.AsyncClient(timeout=60) as client:
            total_chunks = 0
            for i, exp in enumerate(experiences):
                full_text = _experience_text(exp)
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)

                for idx, chunk_text_content in enumerate(chunks):
                    embedding = await get_embedding(chunk_text_content, client)
                    metadata = json.dumps({
                        "title": exp.title,
                        "substance": exp.substance,
                        "category": exp.category,
                        "erowid_id": exp.erowid_id,
                    })
                    doc_chunk = DocumentChunk(
                        source_type="experience",
                        source_id=exp.id,
                        chunk_index=idx,
                        content=chunk_text_content,
                        metadata_json=metadata,
                        embedding=embedding,
                    )
                    db.add(doc_chunk)
                    total_chunks += 1

                # Commit periodically so a crash loses at most one batch.
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info("Embedded %s experiences (%s chunks)", i + 1, total_chunks)

            await db.commit()
            logger.info(
                "Done! Created %s chunks from %s experiences",
                total_chunks, len(experiences),
            )
            return total_chunks


async def embed_substances(batch_size: int = 10) -> int:
    """Chunk and embed all un-embedded substance info pages.

    Args:
        batch_size: Number of substances processed between DB commits.

    Returns:
        Total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Find substances that don't have chunks yet.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "substance"
        ).distinct()
        result = await db.execute(
            select(Substance).where(Substance.id.not_in(subq))
        )
        substances = result.scalars().all()
        logger.info("Found %s substances to embed", len(substances))

        async with httpx.AsyncClient(timeout=60) as client:
            total_chunks = 0
            for i, sub in enumerate(substances):
                full_text = _substance_text(sub)
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)

                for idx, chunk_text_content in enumerate(chunks):
                    embedding = await get_embedding(chunk_text_content, client)
                    metadata = json.dumps({
                        "substance": sub.name,
                        "category": sub.category,
                    })
                    doc_chunk = DocumentChunk(
                        source_type="substance",
                        source_id=sub.id,
                        chunk_index=idx,
                        content=chunk_text_content,
                        metadata_json=metadata,
                        embedding=embedding,
                    )
                    db.add(doc_chunk)
                    total_chunks += 1

                # Commit periodically so a crash loses at most one batch.
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info("Embedded %s substances (%s chunks)", i + 1, total_chunks)

            await db.commit()
            logger.info(
                "Done! Created %s chunks from %s substances",
                total_chunks, len(substances),
            )
            return total_chunks


async def embed_all() -> dict[str, int]:
    """Embed everything that hasn't been embedded yet."""
    exp_chunks = await embed_experiences()
    sub_chunks = await embed_substances()
    return {"experience_chunks": exp_chunks, "substance_chunks": sub_chunks}