# p2pwiki-ai/src/embeddings.py
# (file-viewer metadata — "257 lines, 8.8 KiB, Python" — retained as a comment)
"""Vector store setup and embedding generation using ChromaDB."""
import json
from pathlib import Path
from typing import Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from rich.console import Console
from rich.progress import Progress
from sentence_transformers import SentenceTransformer
from .config import settings
from .parser import WikiArticle
console = Console()
# Chunk size for embedding (in characters)
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
class WikiVectorStore:
"""Vector store for wiki articles using ChromaDB."""
def __init__(self, persist_dir: Optional[Path] = None):
self.persist_dir = persist_dir or settings.chroma_persist_dir
# Initialize ChromaDB
self.client = chromadb.PersistentClient(
path=str(self.persist_dir),
settings=ChromaSettings(anonymized_telemetry=False),
)
# Create or get collection
self.collection = self.client.get_or_create_collection(
name="wiki_articles",
metadata={"hnsw:space": "cosine"},
)
# Load embedding model
console.print(f"[cyan]Loading embedding model: {settings.embedding_model}[/cyan]")
self.model = SentenceTransformer(settings.embedding_model)
console.print("[green]Model loaded[/green]")
def _chunk_text(self, text: str, title: str) -> list[tuple[str, dict]]:
"""Split text into overlapping chunks with metadata."""
if len(text) <= CHUNK_SIZE:
return [(text, {"chunk_index": 0, "total_chunks": 1})]
chunks = []
start = 0
chunk_index = 0
while start < len(text):
end = start + CHUNK_SIZE
# Try to break at sentence boundary
if end < len(text):
# Look for sentence end within last 100 chars
for i in range(min(100, end - start)):
if text[end - i] in ".!?\n":
end = end - i + 1
break
chunk_text = text[start:end].strip()
if chunk_text:
# Prepend title for context
chunk_with_title = f"{title}\n\n{chunk_text}"
chunks.append(
(chunk_with_title, {"chunk_index": chunk_index, "total_chunks": -1})
)
chunk_index += 1
start = end - CHUNK_OVERLAP
# Update total_chunks
for i, (text, meta) in enumerate(chunks):
meta["total_chunks"] = len(chunks)
return chunks
def get_embedded_article_ids(self) -> set:
"""Get set of article IDs that are already embedded."""
results = self.collection.get(include=["metadatas"])
article_ids = set()
for meta in results["metadatas"]:
if meta and "article_id" in meta:
article_ids.add(meta["article_id"])
return article_ids
def add_articles(self, articles: list[WikiArticle], batch_size: int = 100, resume: bool = True):
"""Add articles to the vector store."""
console.print(f"[cyan]Processing {len(articles)} articles...[/cyan]")
# Check for already embedded articles if resuming
if resume:
embedded_ids = self.get_embedded_article_ids()
original_count = len(articles)
articles = [a for a in articles if a.id not in embedded_ids]
skipped = original_count - len(articles)
if skipped > 0:
console.print(f"[yellow]Skipping {skipped} already-embedded articles[/yellow]")
if not articles:
console.print("[green]All articles already embedded![/green]")
return
all_chunks = []
all_ids = []
all_metadatas = []
with Progress() as progress:
task = progress.add_task("[cyan]Chunking articles...", total=len(articles))
for article in articles:
if not article.plain_text:
progress.advance(task)
continue
chunks = self._chunk_text(article.plain_text, article.title)
for chunk_text, chunk_meta in chunks:
chunk_id = f"{article.id}_{chunk_meta['chunk_index']}"
metadata = {
"article_id": article.id,
"title": article.title,
"categories": ",".join(article.categories[:10]), # Limit categories
"timestamp": article.timestamp,
"chunk_index": chunk_meta["chunk_index"],
"total_chunks": chunk_meta["total_chunks"],
}
all_chunks.append(chunk_text)
all_ids.append(chunk_id)
all_metadatas.append(metadata)
progress.advance(task)
console.print(f"[cyan]Created {len(all_chunks)} chunks from {len(articles)} articles[/cyan]")
# Generate embeddings and add in batches
console.print("[cyan]Generating embeddings and adding to vector store...[/cyan]")
with Progress() as progress:
task = progress.add_task(
"[cyan]Embedding and storing...", total=len(all_chunks) // batch_size + 1
)
for i in range(0, len(all_chunks), batch_size):
batch_chunks = all_chunks[i : i + batch_size]
batch_ids = all_ids[i : i + batch_size]
batch_metadatas = all_metadatas[i : i + batch_size]
# Generate embeddings
embeddings = self.model.encode(batch_chunks, show_progress_bar=False)
# Add to collection
self.collection.add(
ids=batch_ids,
embeddings=embeddings.tolist(),
documents=batch_chunks,
metadatas=batch_metadatas,
)
progress.advance(task)
console.print(f"[green]Added {len(all_chunks)} chunks to vector store[/green]")
def search(
self,
query: str,
n_results: int = 5,
filter_categories: Optional[list[str]] = None,
) -> list[dict]:
"""Search for relevant chunks."""
query_embedding = self.model.encode([query])[0]
where_filter = None
if filter_categories:
# ChromaDB where filter for categories
where_filter = {
"$or": [{"categories": {"$contains": cat}} for cat in filter_categories]
}
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=n_results,
where=where_filter,
include=["documents", "metadatas", "distances"],
)
# Format results
formatted = []
if results["documents"] and results["documents"][0]:
for i, doc in enumerate(results["documents"][0]):
formatted.append(
{
"content": doc,
"metadata": results["metadatas"][0][i],
"distance": results["distances"][0][i],
}
)
return formatted
def get_article_titles(self) -> list[str]:
"""Get all unique article titles in the store."""
# Get all metadata
results = self.collection.get(include=["metadatas"])
titles = set()
for meta in results["metadatas"]:
if meta and "title" in meta:
titles.add(meta["title"])
return sorted(titles)
def get_stats(self) -> dict:
"""Get statistics about the vector store."""
count = self.collection.count()
# Get sample of metadatas to count unique articles
sample = self.collection.get(limit=10000, include=["metadatas"])
unique_articles = len(set(m["article_id"] for m in sample["metadatas"] if m))
return {
"total_chunks": count,
"unique_articles_sampled": unique_articles,
"persist_dir": str(self.persist_dir),
}
def main():
"""CLI entry point for generating embeddings."""
articles_path = settings.data_dir / "articles.json"
if not articles_path.exists():
console.print(f"[red]Articles file not found: {articles_path}[/red]")
console.print("[yellow]Run 'python -m src.parser' first to parse XML dumps[/yellow]")
return
console.print(f"[cyan]Loading articles from {articles_path}...[/cyan]")
with open(articles_path, "r", encoding="utf-8") as f:
articles_data = json.load(f)
articles = [WikiArticle(**a) for a in articles_data]
console.print(f"[green]Loaded {len(articles)} articles[/green]")
store = WikiVectorStore()
store.add_articles(articles)
stats = store.get_stats()
console.print(f"[green]Vector store stats: {stats}[/green]")
if __name__ == "__main__":
main()