"""Chunk and embed experience reports and substance info pages via Ollama embeddings."""
import json
|
|
import logging
|
|
from typing import AsyncGenerator
|
|
|
|
import httpx
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import settings
|
|
from app.database import async_session
|
|
from app.models import Experience, Substance, DocumentChunk
|
|
|
|
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Persistent HTTP client for Ollama calls (avoids TCP handshake per request).
# Lazily created by _get_ollama_client(); None until first use.
_ollama_client: httpx.AsyncClient | None = None
|
|
|
|
|
|
def _get_ollama_client() -> httpx.AsyncClient:
    """Return the shared Ollama HTTP client, creating it on first use.

    A fresh client is also created when the previous one was closed, so
    callers always receive a usable connection pool.
    """
    global _ollama_client
    needs_new = _ollama_client is None or _ollama_client.is_closed
    if needs_new:
        # Generous read timeout: embedding large chunks can be slow.
        timeouts = httpx.Timeout(connect=10, read=120, write=10, pool=10)
        _ollama_client = httpx.AsyncClient(
            base_url=settings.ollama_base_url,
            timeout=timeouts,
        )
    return _ollama_client
|
|
|
|
|
|
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split *text* into overlapping chunks by approximate token count.

    Token counts are approximated as ~0.75 words per token.

    Args:
        text: Text to split (split on whitespace).
        chunk_size: Target chunk size in approximate tokens.
        overlap: Overlap between consecutive chunks in approximate tokens.

    Returns:
        List of chunk strings; a single-element list when the whole text
        fits in one chunk (including empty text).
    """
    words = text.split()
    # Guard degenerate parameters: a chunk must hold at least one word, and
    # the overlap must be strictly smaller than the chunk so the window
    # always advances (the previous version looped forever when
    # overlap >= chunk_size, or when chunk_size < 2 made the chunk 0 words).
    words_per_chunk = max(1, int(chunk_size * 0.75))
    words_overlap = min(max(0, int(overlap * 0.75)), words_per_chunk - 1)

    if len(words) <= words_per_chunk:
        return [text]

    chunks = []
    start = 0
    while start < len(words):
        end = start + words_per_chunk
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # Stop once the final words are covered; stepping back by the
            # overlap here would emit a redundant tail chunk fully contained
            # in the one just appended.
            break
        start = end - words_overlap
    return chunks
|
|
|
|
|
|
async def get_embedding(text: str, client: httpx.AsyncClient | None = None) -> list[float]:
    """Get embedding vector for text using Ollama.

    Falls back to the shared module-level client when *client* is omitted.
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    http = client or _get_ollama_client()
    payload = {
        "model": settings.ollama_embed_model,
        "prompt": text,
    }
    resp = await http.post("/api/embeddings", json=payload)
    resp.raise_for_status()
    return resp.json()["embedding"]
|
|
|
|
|
|
async def get_embeddings_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Get embeddings for multiple texts sequentially (one request per text)."""
    return [await get_embedding(item, client) for item in texts]
|
|
|
|
|
|
async def embed_experiences(batch_size: int = 20):
    """Chunk and embed all un-embedded experience reports.

    Finds every Experience with no DocumentChunk rows yet, builds a text
    representation (metadata header + report body), splits it with
    chunk_text, embeds each chunk via Ollama, and stores the chunks.
    Commits every *batch_size* experiences so a crash loses at most one
    batch of work.

    Args:
        batch_size: Number of experiences per DB commit.

    Returns:
        Total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Find experiences that don't have chunks yet
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "experience"
        ).distinct()
        result = await db.execute(
            select(Experience).where(Experience.id.not_in(subq))
        )
        experiences = result.scalars().all()
        logger.info(f"Found {len(experiences)} experiences to embed")

        # BUG FIX: the previous ad-hoc client had no base_url, so the
        # relative "/api/embeddings" path used by get_embedding could not
        # resolve and every request failed.
        async with httpx.AsyncClient(
            base_url=settings.ollama_base_url, timeout=60
        ) as client:
            total_chunks = 0
            for i, exp in enumerate(experiences):
                # Build a rich text representation: searchable metadata
                # header followed by the report body.
                header = f"Experience Report: {exp.title}\n"
                header += f"Substance: {exp.substance}\n"
                if exp.category:
                    header += f"Category: {exp.category}\n"
                if exp.gender:
                    header += f"Gender: {exp.gender}\n"
                if exp.age:
                    header += f"Age: {exp.age}\n"
                header += "\n"

                full_text = header + exp.body
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)

                for idx, chunk_text_content in enumerate(chunks):
                    embedding = await get_embedding(chunk_text_content, client)

                    metadata = json.dumps({
                        "title": exp.title,
                        "substance": exp.substance,
                        "category": exp.category,
                        "erowid_id": exp.erowid_id,
                    })

                    db.add(DocumentChunk(
                        source_type="experience",
                        source_id=exp.id,
                        chunk_index=idx,
                        content=chunk_text_content,
                        metadata_json=metadata,
                        embedding=embedding,
                    ))
                    total_chunks += 1

                # Periodic commit so progress survives interruption.
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info(f"Embedded {i + 1} experiences ({total_chunks} chunks)")

        # Final commit for the trailing partial batch.
        await db.commit()
        logger.info(f"Done! Created {total_chunks} chunks from {len(experiences)} experiences")
        return total_chunks
|
|
|
|
|
|
async def embed_substances(batch_size: int = 10):
    """Chunk and embed all un-embedded substance info pages.

    Finds every Substance with no DocumentChunk rows yet, assembles its
    populated sections into one document, splits it with chunk_text,
    embeds each chunk via Ollama, and stores the chunks. Commits every
    *batch_size* substances so a crash loses at most one batch of work.

    Args:
        batch_size: Number of substances per DB commit.

    Returns:
        Total number of DocumentChunk rows created.
    """
    async with async_session() as db:
        # Skip substances that already have chunks.
        subq = select(DocumentChunk.source_id).where(
            DocumentChunk.source_type == "substance"
        ).distinct()
        result = await db.execute(
            select(Substance).where(Substance.id.not_in(subq))
        )
        substances = result.scalars().all()
        logger.info(f"Found {len(substances)} substances to embed")

        # BUG FIX: the previous ad-hoc client had no base_url, so the
        # relative "/api/embeddings" path used by get_embedding could not
        # resolve and every request failed.
        async with httpx.AsyncClient(
            base_url=settings.ollama_base_url, timeout=60
        ) as client:
            total_chunks = 0
            for i, sub in enumerate(substances):
                # Build rich text representation from populated sections only.
                sections = []
                sections.append(f"Substance Information: {sub.name}")
                if sub.category:
                    sections.append(f"Category: {sub.category}")
                if sub.description:
                    sections.append(f"\nOverview:\n{sub.description}")
                if sub.effects:
                    sections.append(f"\nEffects:\n{sub.effects}")
                if sub.dosage:
                    sections.append(f"\nDosage:\n{sub.dosage}")
                if sub.duration:
                    sections.append(f"\nDuration:\n{sub.duration}")
                if sub.chemistry:
                    sections.append(f"\nChemistry:\n{sub.chemistry}")
                if sub.health:
                    sections.append(f"\nHealth & Safety:\n{sub.health}")
                if sub.law:
                    sections.append(f"\nLegal Status:\n{sub.law}")

                full_text = "\n".join(sections)
                chunks = chunk_text(full_text, settings.chunk_size, settings.chunk_overlap)

                for idx, chunk_text_content in enumerate(chunks):
                    embedding = await get_embedding(chunk_text_content, client)

                    metadata = json.dumps({
                        "substance": sub.name,
                        "category": sub.category,
                    })

                    db.add(DocumentChunk(
                        source_type="substance",
                        source_id=sub.id,
                        chunk_index=idx,
                        content=chunk_text_content,
                        metadata_json=metadata,
                        embedding=embedding,
                    ))
                    total_chunks += 1

                # Periodic commit so progress survives interruption.
                if (i + 1) % batch_size == 0:
                    await db.commit()
                    logger.info(f"Embedded {i + 1} substances ({total_chunks} chunks)")

        # Final commit for the trailing partial batch.
        await db.commit()
        logger.info(f"Done! Created {total_chunks} chunks from {len(substances)} substances")
        return total_chunks
|
|
|
|
|
|
async def embed_all():
    """Embed everything that hasn't been embedded yet.

    Runs experiences first, then substances; returns chunk counts per
    source type.
    """
    return {
        "experience_chunks": await embed_experiences(),
        "substance_chunks": await embed_substances(),
    }
|