"""Store-and-forward relay for offline federation peers.

Uses Redis sorted sets with TTL.
Events queued per peer, flushed on schedule.
"""

import json
import time
import logging
from typing import Any

import httpx

log = logging.getLogger(__name__)

RELAY_PREFIX = "spore:relay:"
RELAY_TTL_SECONDS = 7 * 24 * 3600  # 7 days


class FederationRelay:
    """Per-peer event queue backed by Redis sorted sets.

    Each peer has one sorted set at ``RELAY_PREFIX + peer_rid``. Members are
    the JSON text of queued events; scores are the enqueue time from
    ``time.time()``, so reading by ascending score replays events oldest
    first.
    """

    def __init__(self, redis_url: str):
        self._redis_url = redis_url
        # Client is created lazily by _get_redis(), so constructing the relay
        # performs no I/O.
        self._redis = None

    async def _get_redis(self):
        """Return the cached async Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url)
        return self._redis

    async def enqueue(self, peer_rid: str, event: dict) -> None:
        """Queue an event for a peer.

        NOTE: the sorted-set member is the JSON text of ``event``. Redis
        sorted-set members are unique, so enqueueing an event whose
        serialization matches one already queued does not add a second copy;
        ZADD only updates that member's score (timestamp).
        """
        r = await self._get_redis()
        key = f"{RELAY_PREFIX}{peer_rid}"
        score = time.time()
        await r.zadd(key, {json.dumps(event): score})
        # Refresh the key TTL on every enqueue. This only bounds how long an
        # idle queue survives; individual old events are removed by
        # prune_expired().
        await r.expire(key, RELAY_TTL_SECONDS)

    async def flush_peer(self, peer_rid: str, peer_url: str) -> int:
        """Attempt to flush queued events to a peer, oldest first.

        Each event is POSTed as JSON to
        ``{peer_url}/koi-net/events/broadcast``. It is removed from the queue
        only after the response passes ``raise_for_status()``. Delivery stops
        at the first failed send, so the remaining events keep their order for
        the next attempt.

        Entries that cannot be decoded as JSON can never be delivered. They
        are logged and removed rather than allowed to raise out of this method
        (which would also abort ``flush_all`` for every later peer).

        Returns the number of events successfully delivered.
        """
        r = await self._get_redis()
        key = f"{RELAY_PREFIX}{peer_rid}"
        # Ascending score == ascending enqueue time.
        events = await r.zrangebyscore(key, "-inf", "+inf")
        if not events:
            return 0

        # Strip a trailing slash so "http://peer/" does not yield "//koi-net".
        url = f"{peer_url.rstrip('/')}/koi-net/events/broadcast"
        flushed = 0
        async with httpx.AsyncClient(timeout=30) as client:
            for raw in events:
                try:
                    event = json.loads(raw)
                except ValueError:
                    # JSONDecodeError and UnicodeDecodeError are both
                    # ValueError subclasses. Retrying can never succeed, so
                    # drop the entry instead of blocking the queue forever.
                    log.error(
                        "Dropping undecodable relay entry for %s: %r",
                        peer_rid,
                        raw[:200],
                    )
                    await r.zrem(key, raw)
                    continue
                try:
                    resp = await client.post(url, json=event)
                    resp.raise_for_status()
                    await r.zrem(key, raw)
                    flushed += 1
                except Exception as e:
                    log.warning("Failed to flush to %s: %s", peer_rid, e)
                    break  # Stop on first failure to preserve ordering
        return flushed

    async def flush_all(self, peers: list[dict]) -> dict[str, int]:
        """Flush events to all approved peers.

        Peers whose ``handshake_status`` is not ``"approved"`` are skipped.
        An unexpected error while flushing one peer is logged and does not
        prevent the remaining peers from being flushed.

        Returns {peer_rid: count} for peers that received at least one event.
        """
        results: dict[str, int] = {}
        for peer in peers:
            if peer.get("handshake_status") != "approved":
                continue
            # Missing keys are a caller bug and still raise KeyError here.
            rid = peer["node_rid"]
            url = peer["node_url"]
            try:
                count = await self.flush_peer(rid, url)
            except Exception:
                log.exception("Relay flush for peer %s failed", rid)
                continue
            if count > 0:
                results[rid] = count
        return results

    async def pending_count(self, peer_rid: str) -> int:
        """Get count of pending events for a peer (ZCARD of its queue)."""
        r = await self._get_redis()
        key = f"{RELAY_PREFIX}{peer_rid}"
        return await r.zcard(key)

    async def prune_expired(self) -> int:
        """Remove events older than TTL. Returns count removed.

        Scans every relay queue key and deletes members whose enqueue
        timestamp (score) is at or before ``now - RELAY_TTL_SECONDS``.
        """
        r = await self._get_redis()
        cutoff = time.time() - RELAY_TTL_SECONDS
        total = 0
        async for key in r.scan_iter(f"{RELAY_PREFIX}*"):
            removed = await r.zremrangebyscore(key, "-inf", cutoff)
            total += removed
        return total