95 lines
3.0 KiB
Python
95 lines
3.0 KiB
Python
"""Store-and-forward relay for offline federation peers.
|
|
|
|
Uses Redis sorted sets with TTL. Events queued per peer, flushed on schedule.
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import logging
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
log = logging.getLogger(__name__)

# Namespace for per-peer relay queues; the full Redis key is RELAY_PREFIX + peer_rid.
RELAY_PREFIX = "spore:relay:"

# Retention window for queued events: one week, expressed in seconds.
RELAY_TTL_SECONDS = 7 * 24 * 60 * 60
|
|
|
|
|
|
class FederationRelay:
    """Per-peer store-and-forward queue backed by Redis sorted sets.

    Each peer has one sorted set at ``RELAY_PREFIX + peer_rid``. Members are
    JSON-encoded events and scores are the wall-clock time (``time.time()``)
    at which they were enqueued, so reading by ascending score yields events
    oldest-first.
    """

    def __init__(self, redis_url: str):
        # URL passed to ``redis.asyncio.from_url`` on first use.
        self._redis_url = redis_url
        # Client is created lazily by _get_redis().
        self._redis = None

    @staticmethod
    def _queue_key(peer_rid: str) -> str:
        """Return the Redis key holding *peer_rid*'s event queue."""
        return f"{RELAY_PREFIX}{peer_rid}"

    async def _get_redis(self):
        """Return the shared Redis client, creating it on first call.

        The ``redis`` import is deferred to first use (presumably so this
        module can be imported without redis installed). There is no ``await``
        between the ``None`` check and the assignment, so concurrent coroutines
        on one event loop cannot create two clients.
        """
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url)
        return self._redis

    async def enqueue(self, peer_rid: str, event: dict) -> None:
        """Queue an event for a peer.

        Args:
            peer_rid: Destination peer identifier; selects the queue key.
            event: JSON-serialisable payload to deliver later.

        Note:
            Sorted-set members are unique, so enqueuing an event whose JSON
            encoding equals one already queued does not add a second copy; it
            only updates the existing entry's timestamp.
        """
        r = await self._get_redis()
        key = self._queue_key(peer_rid)
        score = time.time()
        await r.zadd(key, {json.dumps(event): score})
        # Refreshed on every enqueue: the whole key only expires after a full
        # TTL with no new events. Per-event expiry is done by prune_expired().
        await r.expire(key, RELAY_TTL_SECONDS)

    async def flush_peer(self, peer_rid: str, peer_url: str) -> int:
        """Attempt to flush queued events to a peer. Returns count flushed.

        Events are POSTed oldest-first to
        ``<peer_url>/koi-net/events/broadcast`` and removed from the queue only
        after the peer responds without an error status. Delivery stops at the
        first failed request so the remaining events keep their order for the
        next attempt.

        Queue entries that are not valid JSON are logged and removed rather
        than raising; otherwise a single corrupt entry would abort every future
        flush for this peer.

        Args:
            peer_rid: Peer identifier; selects the queue key.
            peer_url: Base URL of the peer. A trailing ``/`` is ignored.

        Returns:
            Number of events delivered and removed from the queue.

        Raises:
            Redis errors from reading or removing queue entries propagate.
        """
        r = await self._get_redis()
        key = self._queue_key(peer_rid)

        # Whole queue, ascending by enqueue time.
        events = await r.zrangebyscore(key, "-inf", "+inf")
        if not events:
            return 0

        # Avoid "//koi-net/..." when the configured URL ends with a slash.
        endpoint = f"{peer_url.rstrip('/')}/koi-net/events/broadcast"

        flushed = 0
        async with httpx.AsyncClient(timeout=30) as client:
            for raw in events:
                try:
                    event = json.loads(raw)
                except ValueError as e:
                    # JSONDecodeError and UnicodeDecodeError both subclass
                    # ValueError. Drop the entry so it cannot block the queue.
                    log.error(
                        "Dropping undecodable relay event for %s: %s", peer_rid, e
                    )
                    await r.zrem(key, raw)
                    continue

                try:
                    resp = await client.post(endpoint, json=event)
                    resp.raise_for_status()
                except (httpx.HTTPError, httpx.InvalidURL) as e:
                    log.warning("Failed to flush to %s: %s", peer_rid, e)
                    break  # Stop on first failure to preserve ordering

                # Only remove after confirmed delivery (at-least-once).
                await r.zrem(key, raw)
                flushed += 1

        return flushed

    async def flush_all(self, peers: list[dict]) -> dict[str, int]:
        """Flush events to all approved peers. Returns {peer_rid: count}.

        Peers whose ``handshake_status`` is not ``"approved"`` are skipped.
        Each peer dict must provide ``node_rid`` and ``node_url``. Only peers
        with at least one delivered event appear in the result.
        """
        results = {}
        for peer in peers:
            if peer.get("handshake_status") != "approved":
                continue
            count = await self.flush_peer(peer["node_rid"], peer["node_url"])
            if count > 0:
                results[peer["node_rid"]] = count
        return results

    async def pending_count(self, peer_rid: str) -> int:
        """Return the number of events currently queued for a peer."""
        r = await self._get_redis()
        return await r.zcard(self._queue_key(peer_rid))

    async def prune_expired(self) -> int:
        """Remove events older than RELAY_TTL_SECONDS. Returns count removed.

        Iterates every relay queue key via SCAN (non-blocking for Redis) and
        deletes members whose enqueue timestamp is at or before the cutoff.
        """
        r = await self._get_redis()
        cutoff = time.time() - RELAY_TTL_SECONDS
        total = 0

        async for key in r.scan_iter(f"{RELAY_PREFIX}*"):
            removed = await r.zremrangebyscore(key, "-inf", cutoff)
            total += removed

        return total
|