"""Tests for TWAP oracle primitive.""" import math import pytest from src.primitives.twap_oracle import ( PriceObservation, TWAPOracleParams, TWAPOracleState, create_oracle, record_observation, compute_twap, compute_vwap, spot_vs_twap_deviation, get_volatility, ) class TestPriceObservation: def test_record_single(self): oracle = create_oracle() oracle = record_observation(oracle, price=1.0, time=0.0) assert len(oracle.observations) == 1 assert oracle.observations[0].price == 1.0 assert oracle.observations[0].time == 0.0 def test_record_multiple(self): oracle = create_oracle() oracle = record_observation(oracle, 1.0, 0.0) oracle = record_observation(oracle, 1.5, 1.0) oracle = record_observation(oracle, 2.0, 2.0) assert len(oracle.observations) == 3 assert oracle.observations[-1].price == 2.0 def test_record_with_volume(self): oracle = create_oracle() oracle = record_observation(oracle, 1.0, 0.0, volume=100.0) assert oracle.observations[0].volume == 100.0 def test_cumulative_accumulator_updates(self): oracle = create_oracle() oracle = record_observation(oracle, 10.0, 0.0) assert oracle.cumulative_price_time == 0.0 # First obs, no dt oracle = record_observation(oracle, 12.0, 5.0) # cumulative += 10.0 * 5.0 = 50.0 assert oracle.cumulative_price_time == 50.0 oracle = record_observation(oracle, 8.0, 10.0) # cumulative += 12.0 * 5.0 = 60.0 → total 110.0 assert oracle.cumulative_price_time == 110.0 class TestTWAPComputation: def test_constant_price(self): """TWAP of a constant price should be that price.""" oracle = create_oracle() for t in range(10): oracle = record_observation(oracle, 5.0, float(t)) twap = compute_twap(oracle, window=100.0) assert abs(twap - 5.0) < 1e-10 def test_rising_price(self): """TWAP of a rising price should be between start and end.""" oracle = create_oracle() for t in range(10): oracle = record_observation(oracle, 1.0 + t * 0.1, float(t)) twap = compute_twap(oracle, window=100.0) first_price = 1.0 last_price = 1.9 assert first_price < twap < last_price def test_window_slicing(self): """Only observations within the window should contribute.""" oracle = create_oracle() # Old observations at low price for t in range(10): oracle = record_observation(oracle, 1.0, float(t)) # Recent observations at high price for t in range(10, 20): oracle = record_observation(oracle, 10.0, float(t)) # Short window should give ~10.0 short_twap = compute_twap(oracle, window=5.0) assert short_twap == pytest.approx(10.0, abs=0.1) # Long window should be between 1 and 10 long_twap = compute_twap(oracle, window=100.0) assert 1.0 < long_twap < 10.0 def test_insufficient_data(self): """Should return 0.0 with fewer than min_observations.""" oracle = create_oracle() assert compute_twap(oracle) == 0.0 oracle = record_observation(oracle, 5.0, 0.0) assert compute_twap(oracle) == 0.0 # Only 1, need 2 def test_default_window(self): """Uses default_window from params when window=None.""" params = TWAPOracleParams(default_window=5.0) oracle = create_oracle(params) for t in range(20): oracle = record_observation(oracle, float(t), float(t)) twap_default = compute_twap(oracle) twap_explicit = compute_twap(oracle, window=5.0) assert twap_default == twap_explicit def test_two_observations(self): """Minimum case: two observations.""" oracle = create_oracle() oracle = record_observation(oracle, 2.0, 0.0) oracle = record_observation(oracle, 4.0, 1.0) twap = compute_twap(oracle, window=10.0) # Only one interval: price 2.0 for dt=1.0 assert twap == pytest.approx(2.0, abs=0.01) class TestVWAP: def test_equal_volume(self): """With equal volumes, VWAP = simple average.""" oracle = create_oracle() oracle = record_observation(oracle, 10.0, 0.0, volume=100.0) oracle = record_observation(oracle, 20.0, 1.0, volume=100.0) vwap = compute_vwap(oracle, window=10.0) assert vwap == pytest.approx(15.0, abs=0.01) def test_volume_weighting(self): """Higher volume at lower price should pull VWAP down.""" oracle = create_oracle() oracle = record_observation(oracle, 10.0, 0.0, volume=900.0) oracle = record_observation(oracle, 20.0, 1.0, volume=100.0) vwap = compute_vwap(oracle, window=10.0) # (10*900 + 20*100) / 1000 = 11.0 assert vwap == pytest.approx(11.0, abs=0.01) def test_zero_volume(self): """VWAP returns 0.0 when total volume is zero.""" oracle = create_oracle() oracle = record_observation(oracle, 10.0, 0.0, volume=0.0) oracle = record_observation(oracle, 20.0, 1.0, volume=0.0) assert compute_vwap(oracle, window=10.0) == 0.0 def test_insufficient_data(self): oracle = create_oracle() assert compute_vwap(oracle) == 0.0 class TestDeviation: def test_spot_above_twap(self): """Positive deviation when spot > TWAP.""" oracle = create_oracle() for t in range(10): oracle = record_observation(oracle, 10.0, float(t)) dev = spot_vs_twap_deviation(oracle, current_spot=12.0, window=100.0) assert dev > 0 assert dev == pytest.approx(0.2, abs=0.01) def test_spot_below_twap(self): """Negative deviation when spot < TWAP.""" oracle = create_oracle() for t in range(10): oracle = record_observation(oracle, 10.0, float(t)) dev = spot_vs_twap_deviation(oracle, current_spot=8.0, window=100.0) assert dev < 0 assert dev == pytest.approx(-0.2, abs=0.01) def test_spot_equals_twap(self): """Zero deviation when spot == TWAP.""" oracle = create_oracle() for t in range(10): oracle = record_observation(oracle, 10.0, float(t)) dev = spot_vs_twap_deviation(oracle, current_spot=10.0, window=100.0) assert abs(dev) < 1e-10 def test_no_data(self): """Returns 0.0 with no TWAP data.""" oracle = create_oracle() assert spot_vs_twap_deviation(oracle, 10.0) == 0.0 class TestVolatility: def test_constant_price_zero_vol(self): """Constant price = zero volatility.""" oracle = create_oracle() for t in range(20): oracle = record_observation(oracle, 100.0, float(t)) vol = get_volatility(oracle, window=100.0) assert vol == pytest.approx(0.0, abs=1e-10) def test_varying_price_positive_vol(self): """Varying prices = positive volatility.""" oracle = create_oracle() prices = [100, 105, 98, 110, 95, 108, 102, 112, 97, 106] for t, p in enumerate(prices): oracle = record_observation(oracle, float(p), float(t)) vol = get_volatility(oracle, window=100.0) assert vol > 0 def test_insufficient_data(self): """Need at least 3 observations for volatility.""" oracle = create_oracle() oracle = record_observation(oracle, 100.0, 0.0) oracle = record_observation(oracle, 105.0, 1.0) assert get_volatility(oracle) == 0.0 def test_volatility_increases_with_swings(self): """Larger price swings = higher volatility.""" # Small swings oracle_small = create_oracle() for t in range(20): price = 100.0 + (1 if t % 2 == 0 else -1) oracle_small = record_observation(oracle_small, price, float(t)) # Large swings oracle_large = create_oracle() for t in range(20): price = 100.0 + (20 if t % 2 == 0 else -20) oracle_large = record_observation(oracle_large, price, float(t)) vol_small = get_volatility(oracle_small, window=100.0) vol_large = get_volatility(oracle_large, window=100.0) assert vol_large > vol_small class TestRingBuffer: def test_max_observations_respected(self): """Ring buffer evicts oldest when full.""" params = TWAPOracleParams(max_observations=5) oracle = create_oracle(params) for t in range(10): oracle = record_observation(oracle, float(t), float(t)) assert len(oracle.observations) == 5 # Oldest should be t=5 (indices 5-9 kept) assert oracle.observations[0].time == 5.0 assert oracle.observations[-1].time == 9.0 def test_small_buffer(self): """Even buffer size 2 works.""" params = TWAPOracleParams(max_observations=2, min_observations=2) oracle = create_oracle(params) oracle = record_observation(oracle, 1.0, 0.0) oracle = record_observation(oracle, 2.0, 1.0) oracle = record_observation(oracle, 3.0, 2.0) assert len(oracle.observations) == 2 assert oracle.observations[0].price == 2.0 assert oracle.observations[1].price == 3.0 def test_cumulative_survives_eviction(self): """Cumulative accumulator keeps accruing even after eviction.""" params = TWAPOracleParams(max_observations=3) oracle = create_oracle(params) for t in range(6): oracle = record_observation(oracle, 10.0, float(t)) # cumulative = 10*1 + 10*1 + 10*1 + 10*1 + 10*1 = 50.0 assert oracle.cumulative_price_time == 50.0 assert len(oracle.observations) == 3