271 lines
9.8 KiB
Python
271 lines
9.8 KiB
Python
"""Tests for signal router — oracle signals → adaptive parameters."""
|
|
|
|
import math
|
|
import pytest
|
|
|
|
from src.primitives.signal_router import (
|
|
SignalRouterConfig,
|
|
AdaptiveParams,
|
|
RouterSignals,
|
|
extract_signals,
|
|
apply_signals,
|
|
route_signals,
|
|
simulate_signal_routing,
|
|
)
|
|
from src.primitives.twap_oracle import (
|
|
TWAPOracleParams,
|
|
TWAPOracleState,
|
|
create_oracle,
|
|
record_observation,
|
|
)
|
|
|
|
|
|
# --- Helpers ---
|
|
|
|
def _build_oracle(prices: list[float], dt: float = 1.0) -> TWAPOracleState:
    """Return an oracle that has recorded every price in ``prices``.

    The oracle is created with ``default_window=len(prices) * dt`` and the
    i-th price is recorded at timestamp ``float(i) * dt``.
    """
    window = len(prices) * dt
    state = create_oracle(TWAPOracleParams(default_window=window))
    for step, price in enumerate(prices):
        timestamp = float(step) * dt
        # record_observation returns the updated state, so rebind each step.
        state = record_observation(state, price, timestamp)
    return state
|
|
|
|
|
|
def _base_params() -> AdaptiveParams:
    """Return the baseline ``AdaptiveParams`` used by the tests in this module."""
    baseline = {
        "flow_threshold": 0.1,
        "pamm_alpha_bar": 10.0,
        "surge_fee_rate": 0.05,
        "oracle_multiplier_velocity": 0.0,
    }
    return AdaptiveParams(**baseline)
|
|
|
|
|
|
# --- TestExtractSignals ---
|
|
|
|
class TestExtractSignals:
    """Tests for ``extract_signals`` over empty, flat and volatile oracles."""

    def test_insufficient_data_returns_invalid(self):
        # A freshly created oracle has no recorded observations.
        extracted = extract_signals(create_oracle(), 1.0)
        assert extracted.is_valid is False
        assert extracted.twap_deviation == 0.0

    def test_single_observation_returns_invalid(self):
        one_obs_oracle = record_observation(create_oracle(), 1.0, 0.0)
        extracted = extract_signals(one_obs_oracle, 1.0)
        assert extracted.is_valid is False

    def test_stable_prices_zero_deviation(self):
        flat_oracle = _build_oracle([1.0] * 5)
        extracted = extract_signals(flat_oracle, 1.0)
        assert extracted.is_valid is True
        deviation_size = abs(extracted.twap_deviation)
        assert deviation_size < 1e-10
        assert extracted.volatility == 0.0

    def test_spot_above_twap_positive_deviation(self):
        flat_oracle = _build_oracle([1.0] * 5)
        # Spot of 1.5 against a flat history of 1.0.
        extracted = extract_signals(flat_oracle, 1.5)
        assert extracted.is_valid is True
        assert extracted.twap_deviation > 0

    def test_spot_below_twap_negative_deviation(self):
        flat_oracle = _build_oracle([1.0] * 5)
        # Spot of 0.5 against a flat history of 1.0.
        extracted = extract_signals(flat_oracle, 0.5)
        assert extracted.is_valid is True
        assert extracted.twap_deviation < 0

    def test_volatile_prices_nonzero_volatility(self):
        choppy_prices = [1.0, 1.5, 0.8, 1.3, 0.9, 1.1]
        extracted = extract_signals(_build_oracle(choppy_prices), 1.1)
        assert extracted.is_valid is True
        assert extracted.volatility > 0
|
|
|
|
|
|
# --- TestApplySignals ---
|
|
|
|
class TestApplySignals:
    """Tests for how ``apply_signals`` adjusts a base ``AdaptiveParams``."""

    def test_invalid_signals_return_base(self):
        baseline = _base_params()
        not_valid = RouterSignals(0.0, 0.0, 1.0, 0.0, is_valid=False)
        adjusted = apply_signals(not_valid, baseline, SignalRouterConfig())
        assert adjusted == baseline

    def test_volatility_tightens_flow_threshold(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.0,
            volatility=0.2,
            spot_price=1.0,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_vol_flow=1.0)
        )
        # Expected scaling: (1 - 1.0 * 0.2) = 0.8 of the base threshold.
        assert adjusted.flow_threshold < baseline.flow_threshold
        error = abs(adjusted.flow_threshold - 0.08)
        assert error < 1e-10

    def test_deviation_steepens_alpha(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.1,
            volatility=0.0,
            spot_price=1.1,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_dev_alpha=2.0)
        )
        # Expected scaling: (1 + 2.0 * 0.1) = 1.2 of the base alpha.
        assert adjusted.pamm_alpha_bar > baseline.pamm_alpha_bar
        error = abs(adjusted.pamm_alpha_bar - 12.0)
        assert error < 1e-10

    def test_volatility_increases_surge_fee(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.0,
            volatility=0.3,
            spot_price=1.0,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_vol_fee=1.5)
        )
        # Expected scaling: (1 + 1.5 * 0.3) = 1.45 of the base fee.
        target_fee = 0.05 * 1.45
        error = abs(adjusted.surge_fee_rate - target_fee)
        assert error < 1e-10

    def test_deviation_sets_velocity(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=-0.2,
            volatility=0.0,
            spot_price=0.8,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_oracle_vel=0.01)
        )
        # The expected velocity is -0.002; adding 0.002 measures the offset.
        error = abs(adjusted.oracle_multiplier_velocity + 0.002)
        assert error < 1e-10

    def test_clamping_respects_bounds(self):
        # Large k values with high volatility, so both bounds are under test.
        limits = SignalRouterConfig(
            k_vol_flow=10.0,
            k_vol_fee=10.0,
            surge_fee_max=0.5,
            flow_threshold_min=0.01,
        )
        baseline = AdaptiveParams(
            flow_threshold=0.1,
            pamm_alpha_bar=10.0,
            surge_fee_rate=0.4,
            oracle_multiplier_velocity=0.0,
        )
        observed = RouterSignals(
            twap_deviation=0.0,
            volatility=0.5,
            spot_price=1.0,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(observed, baseline, limits)
        assert adjusted.flow_threshold >= limits.flow_threshold_min
        assert adjusted.surge_fee_rate <= limits.surge_fee_max
|
|
|
|
|
|
# --- TestRouteSignals ---
|
|
|
|
class TestRouteSignals:
    """Tests for ``route_signals``, called with and without an oracle."""

    def test_none_oracle_returns_base(self):
        base = _base_params()
        config = SignalRouterConfig()
        # With None passed as the oracle, the base params are expected back.
        result = route_signals(None, base, config)
        assert result == base

    def test_with_oracle_modifies_params(self):
        oracle = _build_oracle([1.0, 1.0, 1.0, 1.2, 1.3])
        base = _base_params()
        config = SignalRouterConfig()
        result = route_signals(oracle, base, config, current_spot=1.5)
        # Spot > TWAP, so deviation > 0 → alpha increases
        assert result.pamm_alpha_bar > base.pamm_alpha_bar

    def test_uses_last_obs_when_no_spot(self):
        oracle = _build_oracle([1.0, 1.0, 1.0])
        base = _base_params()
        config = SignalRouterConfig()
        # current_spot=0 → should use last observation price
        result = route_signals(oracle, base, config, current_spot=0.0)
        # With spot == twap, deviation ≈ 0, so params near base
        assert abs(result.flow_threshold - base.flow_threshold) < base.flow_threshold * 0.5
        # NOTE(review): judging by the apply_signals tests in this file,
        # flow_threshold follows volatility while pamm_alpha_bar follows the
        # TWAP deviation. The flow check alone would therefore still pass if
        # 0.0 were used as the spot. Alpha staying at base is the check that
        # actually exercises the last-observation fallback — confirm against
        # the route_signals implementation.
        assert abs(result.pamm_alpha_bar - base.pamm_alpha_bar) < 1e-6
|
|
|
|
|
|
# --- TestKZeroDisablesLink ---
|
|
|
|
class TestKZeroDisablesLink:
    """Each test zeroes one ``k_*`` coefficient and checks its parameter stays at base."""

    def test_zero_k_vol_flow_leaves_threshold_unchanged(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.0,
            volatility=0.5,
            spot_price=1.0,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_vol_flow=0.0)
        )
        drift = abs(adjusted.flow_threshold - baseline.flow_threshold)
        assert drift < 1e-10

    def test_zero_k_dev_alpha_leaves_alpha_unchanged(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.5,
            volatility=0.0,
            spot_price=1.5,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_dev_alpha=0.0)
        )
        drift = abs(adjusted.pamm_alpha_bar - baseline.pamm_alpha_bar)
        assert drift < 1e-10

    def test_zero_k_vol_fee_leaves_fee_unchanged(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.0,
            volatility=0.5,
            spot_price=1.0,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_vol_fee=0.0)
        )
        drift = abs(adjusted.surge_fee_rate - baseline.surge_fee_rate)
        assert drift < 1e-10

    def test_zero_k_oracle_vel_leaves_velocity_zero(self):
        baseline = _base_params()
        observed = RouterSignals(
            twap_deviation=0.3,
            volatility=0.0,
            spot_price=1.3,
            twap=1.0,
            is_valid=True,
        )
        adjusted = apply_signals(
            observed, baseline, SignalRouterConfig(k_oracle_vel=0.0)
        )
        # Exact comparison: the velocity must be exactly zero here.
        assert adjusted.oracle_multiplier_velocity == 0.0
|
|
|
|
|
|
# --- TestSimulation ---
|
|
|
|
class TestSimulation:
    """Tests for the series returned by ``simulate_signal_routing``."""

    def test_simulation_returns_correct_keys(self):
        expected_keys = {
            "times",
            "flow_threshold",
            "pamm_alpha_bar",
            "surge_fee_rate",
            "oracle_velocity",
            "twap_deviation",
            "volatility",
        }
        flat_prices = [1.0] * 10
        series = simulate_signal_routing(
            _base_params(), SignalRouterConfig(), flat_prices
        )
        assert set(series.keys()) == expected_keys
        assert len(series["times"]) == 10

    def test_volatile_trajectory_tightens_params(self):
        baseline = _base_params()
        router_config = SignalRouterConfig(k_vol_flow=2.0, k_vol_fee=2.0)
        # A trajectory that swings up and down on every step.
        choppy_prices = [1.0, 1.5, 0.8, 1.3, 0.7, 1.4, 0.9, 1.2, 0.85, 1.1]
        series = simulate_signal_routing(baseline, router_config, choppy_prices)
        # Only the final step is checked, once volatility has had a chance to
        # register; the flow threshold must not exceed the base value there.
        last_flow_threshold = series["flow_threshold"][-1]
        assert last_flow_threshold <= baseline.flow_threshold

    def test_stable_trajectory_preserves_params(self):
        baseline = _base_params()
        flat_prices = [1.0] * 20
        series = simulate_signal_routing(
            baseline, SignalRouterConfig(), flat_prices
        )
        # Flat prices: the final parameters are expected to sit at base.
        flow_drift = abs(series["flow_threshold"][-1] - baseline.flow_threshold)
        assert flow_drift < 1e-6
        fee_drift = abs(series["surge_fee_rate"][-1] - baseline.surge_fee_rate)
        assert fee_drift < 1e-6

    def test_simulation_length_matches_trajectory(self):
        ramp_prices = [1.0 + 0.01 * i for i in range(50)]
        series = simulate_signal_routing(
            _base_params(), SignalRouterConfig(), ramp_prices
        )
        # Every output series must have one entry per input price.
        for name in series:
            column = series[name]
            assert len(column) == 50
|