"""Vector store setup and embedding generation using ChromaDB.""" import json from pathlib import Path from typing import Optional import chromadb from chromadb.config import Settings as ChromaSettings from rich.console import Console from rich.progress import Progress from sentence_transformers import SentenceTransformer from .config import settings from .parser import WikiArticle console = Console() # Chunk size for embedding (in characters) CHUNK_SIZE = 1000 CHUNK_OVERLAP = 200 class WikiVectorStore: """Vector store for wiki articles using ChromaDB.""" def __init__(self, persist_dir: Optional[Path] = None): self.persist_dir = persist_dir or settings.chroma_persist_dir # Initialize ChromaDB self.client = chromadb.PersistentClient( path=str(self.persist_dir), settings=ChromaSettings(anonymized_telemetry=False), ) # Create or get collection self.collection = self.client.get_or_create_collection( name="wiki_articles", metadata={"hnsw:space": "cosine"}, ) # Load embedding model console.print(f"[cyan]Loading embedding model: {settings.embedding_model}[/cyan]") self.model = SentenceTransformer(settings.embedding_model) console.print("[green]Model loaded[/green]") def _chunk_text(self, text: str, title: str) -> list[tuple[str, dict]]: """Split text into overlapping chunks with metadata.""" if len(text) <= CHUNK_SIZE: return [(text, {"chunk_index": 0, "total_chunks": 1})] chunks = [] start = 0 chunk_index = 0 while start < len(text): end = start + CHUNK_SIZE # Try to break at sentence boundary if end < len(text): # Look for sentence end within last 100 chars for i in range(min(100, end - start)): if text[end - i] in ".!?\n": end = end - i + 1 break chunk_text = text[start:end].strip() if chunk_text: # Prepend title for context chunk_with_title = f"{title}\n\n{chunk_text}" chunks.append( (chunk_with_title, {"chunk_index": chunk_index, "total_chunks": -1}) ) chunk_index += 1 start = end - CHUNK_OVERLAP # Update total_chunks for i, (text, meta) in enumerate(chunks): meta["total_chunks"] = len(chunks) return chunks def get_embedded_article_ids(self) -> set: """Get set of article IDs that are already embedded.""" results = self.collection.get(include=["metadatas"]) article_ids = set() for meta in results["metadatas"]: if meta and "article_id" in meta: article_ids.add(meta["article_id"]) return article_ids def add_articles(self, articles: list[WikiArticle], batch_size: int = 100, resume: bool = True): """Add articles to the vector store.""" console.print(f"[cyan]Processing {len(articles)} articles...[/cyan]") # Check for already embedded articles if resuming if resume: embedded_ids = self.get_embedded_article_ids() original_count = len(articles) articles = [a for a in articles if a.id not in embedded_ids] skipped = original_count - len(articles) if skipped > 0: console.print(f"[yellow]Skipping {skipped} already-embedded articles[/yellow]") if not articles: console.print("[green]All articles already embedded![/green]") return all_chunks = [] all_ids = [] all_metadatas = [] with Progress() as progress: task = progress.add_task("[cyan]Chunking articles...", total=len(articles)) for article in articles: if not article.plain_text: progress.advance(task) continue chunks = self._chunk_text(article.plain_text, article.title) for chunk_text, chunk_meta in chunks: chunk_id = f"{article.id}_{chunk_meta['chunk_index']}" metadata = { "article_id": article.id, "title": article.title, "categories": ",".join(article.categories[:10]), # Limit categories "timestamp": article.timestamp, "chunk_index": chunk_meta["chunk_index"], "total_chunks": chunk_meta["total_chunks"], } all_chunks.append(chunk_text) all_ids.append(chunk_id) all_metadatas.append(metadata) progress.advance(task) console.print(f"[cyan]Created {len(all_chunks)} chunks from {len(articles)} articles[/cyan]") # Generate embeddings and add in batches console.print("[cyan]Generating embeddings and adding to vector store...[/cyan]") with Progress() as progress: task = progress.add_task( "[cyan]Embedding and storing...", total=len(all_chunks) // batch_size + 1 ) for i in range(0, len(all_chunks), batch_size): batch_chunks = all_chunks[i : i + batch_size] batch_ids = all_ids[i : i + batch_size] batch_metadatas = all_metadatas[i : i + batch_size] # Generate embeddings embeddings = self.model.encode(batch_chunks, show_progress_bar=False) # Add to collection self.collection.add( ids=batch_ids, embeddings=embeddings.tolist(), documents=batch_chunks, metadatas=batch_metadatas, ) progress.advance(task) console.print(f"[green]Added {len(all_chunks)} chunks to vector store[/green]") def search( self, query: str, n_results: int = 5, filter_categories: Optional[list[str]] = None, ) -> list[dict]: """Search for relevant chunks.""" query_embedding = self.model.encode([query])[0] where_filter = None if filter_categories: # ChromaDB where filter for categories where_filter = { "$or": [{"categories": {"$contains": cat}} for cat in filter_categories] } results = self.collection.query( query_embeddings=[query_embedding.tolist()], n_results=n_results, where=where_filter, include=["documents", "metadatas", "distances"], ) # Format results formatted = [] if results["documents"] and results["documents"][0]: for i, doc in enumerate(results["documents"][0]): formatted.append( { "content": doc, "metadata": results["metadatas"][0][i], "distance": results["distances"][0][i], } ) return formatted def get_article_titles(self) -> list[str]: """Get all unique article titles in the store.""" # Get all metadata results = self.collection.get(include=["metadatas"]) titles = set() for meta in results["metadatas"]: if meta and "title" in meta: titles.add(meta["title"]) return sorted(titles) def get_stats(self) -> dict: """Get statistics about the vector store.""" count = self.collection.count() # Get sample of metadatas to count unique articles sample = self.collection.get(limit=10000, include=["metadatas"]) unique_articles = len(set(m["article_id"] for m in sample["metadatas"] if m)) return { "total_chunks": count, "unique_articles_sampled": unique_articles, "persist_dir": str(self.persist_dir), } def main(): """CLI entry point for generating embeddings.""" articles_path = settings.data_dir / "articles.json" if not articles_path.exists(): console.print(f"[red]Articles file not found: {articles_path}[/red]") console.print("[yellow]Run 'python -m src.parser' first to parse XML dumps[/yellow]") return console.print(f"[cyan]Loading articles from {articles_path}...[/cyan]") with open(articles_path, "r", encoding="utf-8") as f: articles_data = json.load(f) articles = [WikiArticle(**a) for a in articles_data] console.print(f"[green]Loaded {len(articles)} articles[/green]") store = WikiVectorStore() store.add_articles(articles) stats = store.get_stats() console.print(f"[green]Vector store stats: {stats}[/green]") if __name__ == "__main__": main()