myco-bonding-curve/tests/test_signal_router.py

271 lines
9.8 KiB
Python

"""Tests for signal router — oracle signals → adaptive parameters."""
import math
import pytest
from src.primitives.signal_router import (
SignalRouterConfig,
AdaptiveParams,
RouterSignals,
extract_signals,
apply_signals,
route_signals,
simulate_signal_routing,
)
from src.primitives.twap_oracle import (
TWAPOracleParams,
TWAPOracleState,
create_oracle,
record_observation,
)
# --- Helpers ---
def _build_oracle(prices: list[float], dt: float = 1.0) -> TWAPOracleState:
    """Return a TWAP oracle fed with *prices*, one observation every *dt*.

    The oracle's default window is sized to span the whole series
    (``len(prices) * dt``); observation ``k`` is recorded at time ``k * dt``.
    """
    window = len(prices) * dt
    state = create_oracle(TWAPOracleParams(default_window=window))
    timestamps = (float(k) * dt for k in range(len(prices)))
    for ts, price in zip(timestamps, prices):
        state = record_observation(state, price, ts)
    return state
def _base_params() -> AdaptiveParams:
    """Return the baseline AdaptiveParams every test starts from.

    flow threshold 0.1, pAMM alpha-bar 10.0, surge fee 0.05 and a zero
    oracle-multiplier velocity.
    """
    baseline = dict(
        flow_threshold=0.1,
        pamm_alpha_bar=10.0,
        surge_fee_rate=0.05,
        oracle_multiplier_velocity=0.0,
    )
    return AdaptiveParams(**baseline)
# --- TestExtractSignals ---
class TestExtractSignals:
    """extract_signals: oracle history + spot price -> RouterSignals."""

    def test_insufficient_data_returns_invalid(self):
        """An empty oracle yields invalid signals with zero deviation."""
        empty = create_oracle()
        out = extract_signals(empty, 1.0)
        assert out.is_valid is False
        assert out.twap_deviation == 0.0

    def test_single_observation_returns_invalid(self):
        """A single observation is not enough history for valid signals."""
        lone = record_observation(create_oracle(), 1.0, 0.0)
        assert extract_signals(lone, 1.0).is_valid is False

    def test_stable_prices_zero_deviation(self):
        """Flat history with spot at the same level: no deviation, no volatility."""
        flat = _build_oracle([1.0] * 5)
        out = extract_signals(flat, 1.0)
        assert out.is_valid is True
        assert abs(out.twap_deviation) < 1e-10
        assert out.volatility == 0.0

    def test_spot_above_twap_positive_deviation(self):
        """Spot above the TWAP produces a positive deviation."""
        out = extract_signals(_build_oracle([1.0] * 5), 1.5)
        assert out.is_valid is True
        assert out.twap_deviation > 0

    def test_spot_below_twap_negative_deviation(self):
        """Spot below the TWAP produces a negative deviation."""
        out = extract_signals(_build_oracle([1.0] * 5), 0.5)
        assert out.is_valid is True
        assert out.twap_deviation < 0

    def test_volatile_prices_nonzero_volatility(self):
        """A choppy price history registers strictly positive volatility."""
        history = [1.0, 1.5, 0.8, 1.3, 0.9, 1.1]
        out = extract_signals(_build_oracle(history), history[-1])
        assert out.is_valid is True
        assert out.volatility > 0
# --- TestApplySignals ---
class TestApplySignals:
    """apply_signals: scale base params by the signals, then clamp to bounds."""

    @staticmethod
    def _valid(twap_deviation, volatility, spot_price, twap):
        """Build a valid RouterSignals from the four raw readings."""
        return RouterSignals(
            twap_deviation=twap_deviation,
            volatility=volatility,
            spot_price=spot_price,
            twap=twap,
            is_valid=True,
        )

    def test_invalid_signals_return_base(self):
        """Signals flagged invalid leave the base params untouched."""
        base = _base_params()
        invalid = RouterSignals(0.0, 0.0, 1.0, 0.0, is_valid=False)
        assert apply_signals(invalid, base, SignalRouterConfig()) == base

    def test_volatility_tightens_flow_threshold(self):
        """flow *= (1 - k_vol_flow * vol): 0.1 * (1 - 1.0 * 0.2) = 0.08."""
        base = _base_params()
        out = apply_signals(
            self._valid(0.0, 0.2, 1.0, 1.0),
            base,
            SignalRouterConfig(k_vol_flow=1.0),
        )
        assert out.flow_threshold < base.flow_threshold
        assert abs(out.flow_threshold - 0.08) < 1e-10

    def test_deviation_steepens_alpha(self):
        """alpha *= (1 + k_dev_alpha * dev): 10 * (1 + 2.0 * 0.1) = 12."""
        base = _base_params()
        out = apply_signals(
            self._valid(0.1, 0.0, 1.1, 1.0),
            base,
            SignalRouterConfig(k_dev_alpha=2.0),
        )
        assert out.pamm_alpha_bar > base.pamm_alpha_bar
        assert abs(out.pamm_alpha_bar - 12.0) < 1e-10

    def test_volatility_increases_surge_fee(self):
        """fee *= (1 + k_vol_fee * vol): 0.05 * (1 + 1.5 * 0.3) = 0.05 * 1.45."""
        out = apply_signals(
            self._valid(0.0, 0.3, 1.0, 1.0),
            _base_params(),
            SignalRouterConfig(k_vol_fee=1.5),
        )
        expected = 0.05 * 1.45
        assert abs(out.surge_fee_rate - expected) < 1e-10

    def test_deviation_sets_velocity(self):
        """velocity = k_oracle_vel * dev: 0.01 * -0.2 = -0.002."""
        out = apply_signals(
            self._valid(-0.2, 0.0, 0.8, 1.0),
            _base_params(),
            SignalRouterConfig(k_oracle_vel=0.01),
        )
        assert abs(out.oracle_multiplier_velocity - (-0.002)) < 1e-10

    def test_clamping_respects_bounds(self):
        """Extreme gains push params past their limits; clamps must hold."""
        base = AdaptiveParams(
            flow_threshold=0.1,
            pamm_alpha_bar=10.0,
            surge_fee_rate=0.4,
            oracle_multiplier_velocity=0.0,
        )
        cfg = SignalRouterConfig(
            k_vol_flow=10.0,
            k_vol_fee=10.0,
            surge_fee_max=0.5,
            flow_threshold_min=0.01,
        )
        out = apply_signals(self._valid(0.0, 0.5, 1.0, 1.0), base, cfg)
        assert out.flow_threshold >= cfg.flow_threshold_min
        assert out.surge_fee_rate <= cfg.surge_fee_max
# --- TestRouteSignals ---
class TestRouteSignals:
    """route_signals: oracle state -> signals -> adapted params."""

    def test_none_oracle_returns_base(self):
        """Without an oracle there is nothing to route; base params come back."""
        base = _base_params()
        config = SignalRouterConfig()
        result = route_signals(None, base, config)
        assert result == base

    def test_with_oracle_modifies_params(self):
        """Spot above the TWAP raises alpha (relies on the default k_dev_alpha > 0)."""
        oracle = _build_oracle([1.0, 1.0, 1.0, 1.2, 1.3])
        base = _base_params()
        config = SignalRouterConfig()
        result = route_signals(oracle, base, config, current_spot=1.5)
        # Spot > TWAP, so deviation > 0 -> alpha increases.
        assert result.pamm_alpha_bar > base.pamm_alpha_bar

    def test_uses_last_obs_when_no_spot(self):
        """current_spot=0 must fall back to the last observed price.

        With a flat history the last observation equals the TWAP, so the
        deviation-driven outputs (alpha, oracle velocity) must stay at base.
        Had the literal spot 0.0 been used, the deviation would be about -1
        and, with the non-zero gains below, both outputs would visibly move.
        Checking only flow_threshold (volatility-driven, and volatility is
        zero here) could not tell the two cases apart.
        """
        oracle = _build_oracle([1.0, 1.0, 1.0])
        base = _base_params()
        # Non-zero deviation gains so a wrong spot would show up in the output.
        config = SignalRouterConfig(k_dev_alpha=1.0, k_oracle_vel=0.01)

        # Precondition: the oracle has enough history for valid signals;
        # otherwise route_signals returns base and the checks are vacuous.
        assert extract_signals(oracle, 1.0).is_valid is True

        result = route_signals(oracle, base, config, current_spot=0.0)

        # Spot == last obs == TWAP -> deviation ~ 0 -> deviation links inert.
        assert abs(result.pamm_alpha_bar - base.pamm_alpha_bar) < 1e-6
        assert abs(result.oracle_multiplier_velocity) < 1e-9
        # Flat prices -> zero volatility -> flow threshold untouched.
        assert abs(result.flow_threshold - base.flow_threshold) < 1e-6
# --- TestKZeroDisablesLink ---
class TestKZeroDisablesLink:
    """A zero gain k_* must switch off its signal -> parameter link."""

    @staticmethod
    def _apply(base, config, twap_deviation, volatility, spot_price):
        """Apply valid signals (TWAP fixed at 1.0) to *base* under *config*."""
        sig = RouterSignals(
            twap_deviation=twap_deviation,
            volatility=volatility,
            spot_price=spot_price,
            twap=1.0,
            is_valid=True,
        )
        return apply_signals(sig, base, config)

    def test_zero_k_vol_flow_leaves_threshold_unchanged(self):
        """k_vol_flow=0: heavy volatility does not move the flow threshold."""
        base = _base_params()
        out = self._apply(base, SignalRouterConfig(k_vol_flow=0.0), 0.0, 0.5, 1.0)
        assert abs(out.flow_threshold - base.flow_threshold) < 1e-10

    def test_zero_k_dev_alpha_leaves_alpha_unchanged(self):
        """k_dev_alpha=0: a large deviation does not move alpha."""
        base = _base_params()
        out = self._apply(base, SignalRouterConfig(k_dev_alpha=0.0), 0.5, 0.0, 1.5)
        assert abs(out.pamm_alpha_bar - base.pamm_alpha_bar) < 1e-10

    def test_zero_k_vol_fee_leaves_fee_unchanged(self):
        """k_vol_fee=0: heavy volatility does not move the surge fee."""
        base = _base_params()
        out = self._apply(base, SignalRouterConfig(k_vol_fee=0.0), 0.0, 0.5, 1.0)
        assert abs(out.surge_fee_rate - base.surge_fee_rate) < 1e-10

    def test_zero_k_oracle_vel_leaves_velocity_zero(self):
        """k_oracle_vel=0: oracle velocity stays exactly zero."""
        out = self._apply(
            _base_params(), SignalRouterConfig(k_oracle_vel=0.0), 0.3, 0.0, 1.3
        )
        assert out.oracle_multiplier_velocity == 0.0
# --- TestSimulation ---
class TestSimulation:
    """simulate_signal_routing: per-step adapted params over a price path."""

    def test_simulation_returns_correct_keys(self):
        """The result dict exposes exactly the documented series."""
        base = _base_params()
        config = SignalRouterConfig()
        prices = [1.0] * 10
        result = simulate_signal_routing(base, config, prices)
        assert set(result.keys()) == {
            "times", "flow_threshold", "pamm_alpha_bar",
            "surge_fee_rate", "oracle_velocity",
            "twap_deviation", "volatility",
        }
        assert len(result["times"]) == 10

    def test_volatile_trajectory_tightens_params(self):
        """A choppy path must strictly tighten the flow threshold by the end."""
        base = _base_params()
        config = SignalRouterConfig(k_vol_flow=2.0, k_vol_fee=2.0)
        # Volatile trajectory
        prices = [1.0, 1.5, 0.8, 1.3, 0.7, 1.4, 0.9, 1.2, 0.85, 1.1]
        result = simulate_signal_routing(base, config, prices)
        # Guard: volatility must have registered by the last step; otherwise
        # the tightening check below would be vacuous.
        assert result["volatility"][-1] > 0
        # Non-zero volatility with k_vol_flow > 0 must push the threshold
        # strictly below base. A `<=` check would also pass if routing never
        # took effect at all.
        assert result["flow_threshold"][-1] < base.flow_threshold

    def test_stable_trajectory_preserves_params(self):
        """Flat prices -> zero volatility and deviation -> every param at base."""
        base = _base_params()
        config = SignalRouterConfig()
        prices = [1.0] * 20
        result = simulate_signal_routing(base, config, prices)
        # Volatility-driven params (checked after the oracle has warmed up).
        assert abs(result["flow_threshold"][-1] - base.flow_threshold) < 1e-6
        assert abs(result["surge_fee_rate"][-1] - base.surge_fee_rate) < 1e-6
        # Deviation-driven params: spot == TWAP, so these must not move either.
        assert abs(result["pamm_alpha_bar"][-1] - base.pamm_alpha_bar) < 1e-6
        assert (
            abs(result["oracle_velocity"][-1] - base.oracle_multiplier_velocity)
            < 1e-6
        )

    def test_simulation_length_matches_trajectory(self):
        """Every output series has one entry per input price."""
        base = _base_params()
        config = SignalRouterConfig()
        prices = [1.0 + 0.01 * i for i in range(50)]
        result = simulate_signal_routing(base, config, prices)
        for key, series in result.items():
            assert len(series) == 50, key