spore-commons/node/spore_node/federation/relay.py

95 lines
3.0 KiB
Python

"""Store-and-forward relay for offline federation peers.
Uses Redis sorted sets with a TTL. Events are queued per peer and flushed on a schedule.
"""
import json
import time
import logging
from typing import Any
import httpx
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Prefix of every relay queue key; the peer's RID is appended to form the Redis key.
RELAY_PREFIX = "spore:relay:"
# Retention window in seconds: used as the queue key's expiry and as the prune cutoff age.
RELAY_TTL_SECONDS = 7 * 24 * 3600 # 7 days
class FederationRelay:
    """Store-and-forward queue holding events for federation peers in Redis.

    Each peer has one sorted set, keyed ``RELAY_PREFIX + peer_rid``, whose
    members are JSON-serialized events scored by their enqueue time.
    Queued events are delivered by ``flush_peer`` / ``flush_all`` and aged
    out by ``prune_expired``.
    """

    def __init__(self, redis_url: str):
        """Remember the Redis URL; the client itself is created on first use."""
        self._redis_url = redis_url
        self._redis = None

    async def _get_redis(self):
        """Return the shared async Redis client, creating it on the first call."""
        if self._redis is None:
            # Imported here rather than at module level, so redis is only
            # loaded once a client is actually needed.
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url)
        return self._redis

    @staticmethod
    def _queue_key(peer_rid: str) -> str:
        """Return the Redis key of the sorted set that queues events for a peer."""
        return f"{RELAY_PREFIX}{peer_rid}"

    @staticmethod
    def _broadcast_url(peer_url: str) -> str:
        """Build the peer's broadcast endpoint, tolerating a trailing slash."""
        # Without the strip, "https://peer/" would yield "https://peer//koi-net/...".
        return f"{peer_url.rstrip('/')}/koi-net/events/broadcast"

    @staticmethod
    def _decode_event(raw: Any) -> Any:
        """Decode one stored member back into an event.

        Returns None when the member is not valid JSON. A stored JSON
        ``null`` also yields None and is therefore treated as undeliverable.
        """
        try:
            return json.loads(raw)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            return None

    async def enqueue(self, peer_rid: str, event: dict) -> None:
        """Queue an event for a peer.

        The event is stored as its JSON serialization, scored by the current
        time. NOTE: sorted-set members are unique, so enqueuing an event
        whose serialization equals one already queued only updates that
        member's score instead of adding a second entry.
        """
        r = await self._get_redis()
        key = self._queue_key(peer_rid)
        score = time.time()
        await r.zadd(key, {json.dumps(event): score})
        # Every enqueue pushes the whole queue's expiry out to a full TTL again.
        await r.expire(key, RELAY_TTL_SECONDS)

    async def flush_peer(self, peer_rid: str, peer_url: str) -> int:
        """Attempt to flush queued events to a peer. Returns count flushed.

        Events are POSTed one at a time in score (enqueue-time) order and
        removed from the queue only after the peer answers with a success
        status. The first delivery failure stops the flush, leaving that
        event and all later ones queued for the next attempt. Entries that
        cannot be decoded are logged, removed, and not counted.
        """
        r = await self._get_redis()
        key = self._queue_key(peer_rid)
        # Fetch the whole queue; ascending score order means oldest first.
        events = await r.zrangebyscore(key, "-inf", "+inf")
        if not events:
            return 0

        endpoint = self._broadcast_url(peer_url)
        flushed = 0
        async with httpx.AsyncClient(timeout=30) as client:
            for raw in events:
                event = self._decode_event(raw)
                if event is None:
                    # An undecodable entry can never be delivered. Left in
                    # place it would raise on every flush and block the
                    # events queued behind it, so drop it and carry on.
                    log.error(
                        "Dropping undecodable relay entry for %s", peer_rid
                    )
                    await r.zrem(key, raw)
                    continue
                try:
                    resp = await client.post(endpoint, json=event)
                    resp.raise_for_status()
                    # Remove only after a confirmed delivery. If this removal
                    # itself fails, the event stays queued and is re-sent on
                    # the next flush.
                    await r.zrem(key, raw)
                    flushed += 1
                except Exception as e:
                    # Deliberately broad: whatever went wrong, the event
                    # stays queued and the flush is retried later.
                    log.warning("Failed to flush to %s: %s", peer_rid, e)
                    break  # Stop on first failure
        return flushed

    async def flush_all(self, peers: list[dict]) -> dict[str, int]:
        """Flush events to all peers. Returns {peer_rid: count}.

        Only peers whose ``handshake_status`` is ``"approved"`` are flushed,
        and only peers with at least one delivered event appear in the
        result. Each peer dict must provide ``node_rid`` and ``node_url``.
        """
        results = {}
        for peer in peers:
            if peer.get("handshake_status") != "approved":
                continue
            count = await self.flush_peer(peer["node_rid"], peer["node_url"])
            if count > 0:
                results[peer["node_rid"]] = count
        return results

    async def pending_count(self, peer_rid: str) -> int:
        """Get count of pending events for a peer."""
        r = await self._get_redis()
        return await r.zcard(self._queue_key(peer_rid))

    async def prune_expired(self) -> int:
        """Remove events older than TTL. Returns count removed.

        Scans every key under ``RELAY_PREFIX`` and deletes the members whose
        score (enqueue time) is at or before ``now - RELAY_TTL_SECONDS``.
        """
        r = await self._get_redis()
        cutoff = time.time() - RELAY_TTL_SECONDS
        total = 0
        async for key in r.scan_iter(f"{RELAY_PREFIX}*"):
            removed = await r.zremrangebyscore(key, "-inf", cutoff)
            total += removed
        return total