"""Vector store setup and embedding generation using ChromaDB."""
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import chromadb
|
|
from chromadb.config import Settings as ChromaSettings
|
|
from rich.console import Console
|
|
from rich.progress import Progress
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
from .config import settings
|
|
from .parser import WikiArticle
|
|
|
|
console = Console()
|
|
|
|
# Chunk size for embedding (in characters)
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200


class WikiVectorStore:
    """Vector store for wiki articles using ChromaDB.

    Splits articles into overlapping, title-prefixed character chunks,
    embeds them with a SentenceTransformer model, and stores them in a
    persistent ChromaDB collection using cosine distance.
    """

    def __init__(self, persist_dir: Optional[Path] = None):
        """Open (or create) the persistent store and load the embedding model.

        Args:
            persist_dir: Directory for ChromaDB persistence. Defaults to
                ``settings.chroma_persist_dir``.
        """
        self.persist_dir = persist_dir or settings.chroma_persist_dir

        # Initialize ChromaDB with telemetry disabled.
        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir),
            settings=ChromaSettings(anonymized_telemetry=False),
        )

        # Create or get collection; cosine space suits sentence embeddings.
        self.collection = self.client.get_or_create_collection(
            name="wiki_articles",
            metadata={"hnsw:space": "cosine"},
        )

        # Load embedding model (may download weights on first run).
        console.print(f"[cyan]Loading embedding model: {settings.embedding_model}[/cyan]")
        self.model = SentenceTransformer(settings.embedding_model)
        console.print("[green]Model loaded[/green]")

    def _chunk_text(self, text: str, title: str) -> list[tuple[str, dict]]:
        """Split ``text`` into overlapping chunks with metadata.

        Every chunk (including the single-chunk case) is prefixed with the
        article title so the embedding carries topical context.

        Returns:
            A list of ``(chunk_text, metadata)`` pairs; metadata holds
            ``chunk_index`` and ``total_chunks``.
        """
        # Short articles become a single chunk. Prefix the title here too,
        # for consistency with the multi-chunk path below (previously only
        # multi-chunk articles carried the title).
        if len(text) <= CHUNK_SIZE:
            return [(f"{title}\n\n{text}", {"chunk_index": 0, "total_chunks": 1})]

        chunks: list[tuple[str, dict]] = []
        start = 0
        chunk_index = 0

        while start < len(text):
            # Deliberately NOT clamped to len(text): slicing clamps
            # naturally, and the unclamped value guarantees `start`
            # advances past the end so the loop terminates.
            end = start + CHUNK_SIZE

            # Prefer to break at a sentence boundary within the last 100
            # chars of the chunk. Start scanning at i=1 so we inspect
            # text[end - 1] (the chunk's final char); i=0 would test
            # text[end], which lies outside the slice.
            if end < len(text):
                for i in range(1, min(100, end - start) + 1):
                    if text[end - i] in ".!?\n":
                        end = end - i + 1
                        break

            body = text[start:end].strip()
            if body:
                # Prepend title for context
                chunks.append(
                    (f"{title}\n\n{body}", {"chunk_index": chunk_index, "total_chunks": -1})
                )
                chunk_index += 1

            # Overlap consecutive chunks so boundary sentences survive.
            start = end - CHUNK_OVERLAP

        # Backfill total_chunks now that the final count is known. Fresh
        # loop names avoid shadowing the `text` parameter.
        for _, meta in chunks:
            meta["total_chunks"] = len(chunks)

        return chunks

    def get_embedded_article_ids(self) -> set:
        """Return the set of article IDs that already have chunks stored."""
        results = self.collection.get(include=["metadatas"])
        article_ids = set()
        for meta in results["metadatas"]:
            if meta and "article_id" in meta:
                article_ids.add(meta["article_id"])
        return article_ids

    def add_articles(self, articles: list[WikiArticle], batch_size: int = 100, resume: bool = True):
        """Chunk, embed, and store articles in the collection.

        Args:
            articles: Parsed wiki articles to ingest.
            batch_size: Number of chunks embedded and stored per batch.
            resume: Skip articles whose IDs are already in the collection.
        """
        console.print(f"[cyan]Processing {len(articles)} articles...[/cyan]")

        # Check for already embedded articles if resuming
        if resume:
            embedded_ids = self.get_embedded_article_ids()
            original_count = len(articles)
            articles = [a for a in articles if a.id not in embedded_ids]
            skipped = original_count - len(articles)
            if skipped > 0:
                console.print(f"[yellow]Skipping {skipped} already-embedded articles[/yellow]")
            if not articles:
                console.print("[green]All articles already embedded![/green]")
                return

        all_chunks = []
        all_ids = []
        all_metadatas = []

        with Progress() as progress:
            task = progress.add_task("[cyan]Chunking articles...", total=len(articles))

            for article in articles:
                # Skip articles with no extractable text.
                if not article.plain_text:
                    progress.advance(task)
                    continue

                chunks = self._chunk_text(article.plain_text, article.title)

                for chunk_text, chunk_meta in chunks:
                    # IDs are "<article_id>_<chunk_index>": stable across
                    # runs so resume can detect already-stored articles.
                    chunk_id = f"{article.id}_{chunk_meta['chunk_index']}"

                    metadata = {
                        "article_id": article.id,
                        "title": article.title,
                        "categories": ",".join(article.categories[:10]),  # Limit categories
                        "timestamp": article.timestamp,
                        "chunk_index": chunk_meta["chunk_index"],
                        "total_chunks": chunk_meta["total_chunks"],
                    }

                    all_chunks.append(chunk_text)
                    all_ids.append(chunk_id)
                    all_metadatas.append(metadata)

                progress.advance(task)

        console.print(f"[cyan]Created {len(all_chunks)} chunks from {len(articles)} articles[/cyan]")

        # Generate embeddings and add in batches
        console.print("[cyan]Generating embeddings and adding to vector store...[/cyan]")

        with Progress() as progress:
            # Ceiling division: the previous `len // batch_size + 1`
            # overcounted by one whenever len was an exact multiple.
            task = progress.add_task(
                "[cyan]Embedding and storing...",
                total=(len(all_chunks) + batch_size - 1) // batch_size,
            )

            for i in range(0, len(all_chunks), batch_size):
                batch_chunks = all_chunks[i : i + batch_size]
                batch_ids = all_ids[i : i + batch_size]
                batch_metadatas = all_metadatas[i : i + batch_size]

                # Generate embeddings
                embeddings = self.model.encode(batch_chunks, show_progress_bar=False)

                # Add to collection
                self.collection.add(
                    ids=batch_ids,
                    embeddings=embeddings.tolist(),
                    documents=batch_chunks,
                    metadatas=batch_metadatas,
                )

                progress.advance(task)

        console.print(f"[green]Added {len(all_chunks)} chunks to vector store[/green]")

    def search(
        self,
        query: str,
        n_results: int = 5,
        filter_categories: Optional[list[str]] = None,
    ) -> list[dict]:
        """Search for the stored chunks most similar to ``query``.

        Args:
            query: Free-text query to embed and match.
            n_results: Maximum number of chunks to return.
            filter_categories: If given, restrict results to chunks whose
                ``categories`` metadata matches one of these.

        Returns:
            Dicts with ``content``, ``metadata`` and ``distance`` keys,
            ordered by ascending distance.
        """
        query_embedding = self.model.encode([query])[0]

        where_filter = None
        if filter_categories:
            # NOTE(review): "$contains" is a document-filter operator in
            # some ChromaDB versions — confirm the installed chromadb
            # accepts it inside a metadata `where` clause.
            clauses = [{"categories": {"$contains": cat}} for cat in filter_categories]
            # ChromaDB rejects "$or" with fewer than two operands, so pass
            # a lone clause directly when only one category is given.
            where_filter = clauses[0] if len(clauses) == 1 else {"$or": clauses}

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Flatten the per-query nesting (we only issued a single query).
        formatted = []
        if results["documents"] and results["documents"][0]:
            for i, doc in enumerate(results["documents"][0]):
                formatted.append(
                    {
                        "content": doc,
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                )

        return formatted

    def get_article_titles(self) -> list[str]:
        """Return all unique article titles in the store, sorted."""
        # Get all metadata
        results = self.collection.get(include=["metadatas"])
        titles = set()
        for meta in results["metadatas"]:
            if meta and "title" in meta:
                titles.add(meta["title"])
        return sorted(titles)

    def get_stats(self) -> dict:
        """Return basic statistics about the vector store."""
        count = self.collection.count()

        # Sample metadata (capped at 10k rows) to estimate unique articles.
        # Membership check avoids the KeyError the old code raised on rows
        # whose metadata lacks "article_id".
        sample = self.collection.get(limit=10000, include=["metadatas"])
        unique_articles = len(
            {m["article_id"] for m in sample["metadatas"] if m and "article_id" in m}
        )

        return {
            "total_chunks": count,
            "unique_articles_sampled": unique_articles,
            "persist_dir": str(self.persist_dir),
        }
|
|
|
|
|
|
def main():
    """CLI entry point: embed all parsed articles into the vector store."""
    articles_path = settings.data_dir / "articles.json"

    # Guard clause: the parser must have produced its JSON output first.
    if not articles_path.exists():
        console.print(f"[red]Articles file not found: {articles_path}[/red]")
        console.print("[yellow]Run 'python -m src.parser' first to parse XML dumps[/yellow]")
        return

    console.print(f"[cyan]Loading articles from {articles_path}...[/cyan]")
    raw_records = json.loads(articles_path.read_text(encoding="utf-8"))

    articles = [WikiArticle(**record) for record in raw_records]
    console.print(f"[green]Loaded {len(articles)} articles[/green]")

    store = WikiVectorStore()
    store.add_articles(articles)

    stats = store.get_stats()
    console.print(f"[green]Vector store stats: {stats}[/green]")


if __name__ == "__main__":
    main()
|