263 lines
9.6 KiB
Python
263 lines
9.6 KiB
Python
"""Tests for TWAP oracle primitive."""
|
|
|
|
import math
|
|
import pytest
|
|
from src.primitives.twap_oracle import (
|
|
PriceObservation, TWAPOracleParams, TWAPOracleState,
|
|
create_oracle, record_observation,
|
|
compute_twap, compute_vwap,
|
|
spot_vs_twap_deviation, get_volatility,
|
|
)
|
|
|
|
|
|
class TestPriceObservation:
    """Recording observations into the buffer and the cumulative accumulator."""

    def test_record_single(self):
        """A single recorded observation keeps its price and time."""
        state = record_observation(create_oracle(), price=1.0, time=0.0)

        assert len(state.observations) == 1
        first = state.observations[0]
        assert first.price == 1.0
        assert first.time == 0.0

    def test_record_multiple(self):
        """Observations accumulate in order; the newest one is last."""
        state = create_oracle()
        for px, ts in ((1.0, 0.0), (1.5, 1.0), (2.0, 2.0)):
            state = record_observation(state, px, ts)

        assert len(state.observations) == 3
        assert state.observations[-1].price == 2.0

    def test_record_with_volume(self):
        """An explicitly supplied volume is stored on the observation."""
        state = record_observation(create_oracle(), 1.0, 0.0, volume=100.0)
        assert state.observations[0].volume == 100.0

    def test_cumulative_accumulator_updates(self):
        """Each new observation adds (previous price * elapsed time)."""
        state = record_observation(create_oracle(), 10.0, 0.0)
        # The very first observation has no preceding interval to accrue.
        assert state.cumulative_price_time == 0.0

        # (price, time, expected cumulative after recording):
        #   10.0 held for 5.0 -> +50.0 ; then 12.0 held for 5.0 -> +60.0
        for px, ts, expected in ((12.0, 5.0, 50.0), (8.0, 10.0, 110.0)):
            state = record_observation(state, px, ts)
            assert state.cumulative_price_time == expected


class TestTWAPComputation:
    """Time-weighted average price over a trailing window."""

    @staticmethod
    def _feed(state, samples):
        """Record each ``(price, time)`` pair in order and return the new state."""
        for px, ts in samples:
            state = record_observation(state, px, ts)
        return state

    def test_constant_price(self):
        """A price that never moves averages to itself."""
        state = self._feed(create_oracle(), ((5.0, float(t)) for t in range(10)))
        twap = compute_twap(state, window=100.0)
        assert abs(twap - 5.0) < 1e-10

    def test_rising_price(self):
        """A monotonically rising series averages strictly between its ends."""
        state = self._feed(
            create_oracle(), ((1.0 + t * 0.1, float(t)) for t in range(10))
        )
        lo, hi = 1.0, 1.9
        assert lo < compute_twap(state, window=100.0) < hi

    def test_window_slicing(self):
        """Observations older than the window do not influence the result."""
        # Ten ticks at 1.0 (t=0..9) followed by ten ticks at 10.0 (t=10..19).
        history = [(1.0, float(t)) for t in range(10)]
        history += [(10.0, float(t)) for t in range(10, 20)]
        state = self._feed(create_oracle(), history)

        # A 5-unit window only reaches back into the 10.0 regime.
        assert compute_twap(state, window=5.0) == pytest.approx(10.0, abs=0.1)
        # A window covering everything blends both regimes.
        assert 1.0 < compute_twap(state, window=100.0) < 10.0

    def test_insufficient_data(self):
        """Below the minimum observation count the TWAP falls back to 0.0."""
        state = create_oracle()
        assert compute_twap(state) == 0.0

        # A single observation is still one short of the required two.
        state = record_observation(state, 5.0, 0.0)
        assert compute_twap(state) == 0.0

    def test_default_window(self):
        """Omitting ``window`` uses ``params.default_window`` instead."""
        state = self._feed(
            create_oracle(TWAPOracleParams(default_window=5.0)),
            ((float(t), float(t)) for t in range(20)),
        )
        assert compute_twap(state) == compute_twap(state, window=5.0)

    def test_two_observations(self):
        """With exactly two points, only the opening price spans the interval."""
        state = self._feed(create_oracle(), [(2.0, 0.0), (4.0, 1.0)])
        # The single interval [0, 1] is priced at the first observation, 2.0.
        assert compute_twap(state, window=10.0) == pytest.approx(2.0, abs=0.01)


class TestVWAP:
    """Volume-weighted average price."""

    @staticmethod
    def _two_ticks(first, second):
        """Build an oracle from two ``(price, volume)`` ticks at t=0.0 and t=1.0."""
        state = create_oracle()
        for ts, (px, vol) in enumerate((first, second)):
            state = record_observation(state, px, float(ts), volume=vol)
        return state

    def test_equal_volume(self):
        """Equal volumes reduce VWAP to the plain mean of the prices."""
        state = self._two_ticks((10.0, 100.0), (20.0, 100.0))
        assert compute_vwap(state, window=10.0) == pytest.approx(15.0, abs=0.01)

    def test_volume_weighting(self):
        """Heavy volume at the low price drags VWAP toward that price."""
        state = self._two_ticks((10.0, 900.0), (20.0, 100.0))
        # (10 * 900 + 20 * 100) / (900 + 100) = 11.0
        assert compute_vwap(state, window=10.0) == pytest.approx(11.0, abs=0.01)

    def test_zero_volume(self):
        """With no traded volume at all, VWAP is reported as 0.0."""
        state = self._two_ticks((10.0, 0.0), (20.0, 0.0))
        assert compute_vwap(state, window=10.0) == 0.0

    def test_insufficient_data(self):
        """An empty oracle yields a VWAP of 0.0."""
        assert compute_vwap(create_oracle()) == 0.0


class TestDeviation:
    """Relative deviation of a spot price from the TWAP."""

    @staticmethod
    def _flat_oracle():
        """Oracle holding ten observations (t=0..9) at a constant 10.0."""
        state = create_oracle()
        for ts in range(10):
            state = record_observation(state, 10.0, float(ts))
        return state

    def test_spot_above_twap(self):
        """Spot above the TWAP gives a positive deviation."""
        dev = spot_vs_twap_deviation(
            self._flat_oracle(), current_spot=12.0, window=100.0
        )
        assert dev > 0
        # 0.2 == (12 - 10) / 10
        assert dev == pytest.approx(0.2, abs=0.01)

    def test_spot_below_twap(self):
        """Spot below the TWAP gives a negative deviation."""
        dev = spot_vs_twap_deviation(
            self._flat_oracle(), current_spot=8.0, window=100.0
        )
        assert dev < 0
        # -0.2 == (8 - 10) / 10
        assert dev == pytest.approx(-0.2, abs=0.01)

    def test_spot_equals_twap(self):
        """Spot equal to the TWAP gives (numerically) zero deviation."""
        dev = spot_vs_twap_deviation(
            self._flat_oracle(), current_spot=10.0, window=100.0
        )
        assert abs(dev) < 1e-10

    def test_no_data(self):
        """Without any TWAP data the deviation is reported as 0.0."""
        assert spot_vs_twap_deviation(create_oracle(), 10.0) == 0.0


class TestVolatility:
    """Volatility estimates derived from the recorded price history."""

    @staticmethod
    def _alternating(amplitude, count=20):
        """Oracle whose price alternates 100 +/- ``amplitude`` at t=0..count-1.

        Even time steps sit above 100, odd ones below, so ``amplitude``
        directly controls the size of each swing.
        """
        oracle = create_oracle()
        for t in range(count):
            price = 100.0 + (amplitude if t % 2 == 0 else -amplitude)
            oracle = record_observation(oracle, price, float(t))
        return oracle

    def test_constant_price_zero_vol(self):
        """A flat price series has zero volatility."""
        oracle = create_oracle()
        for t in range(20):
            oracle = record_observation(oracle, 100.0, float(t))
        vol = get_volatility(oracle, window=100.0)
        assert vol == pytest.approx(0.0, abs=1e-10)

    def test_varying_price_positive_vol(self):
        """A price series that moves has strictly positive volatility."""
        oracle = create_oracle()
        prices = [100, 105, 98, 110, 95, 108, 102, 112, 97, 106]
        for t, p in enumerate(prices):
            oracle = record_observation(oracle, float(p), float(t))
        vol = get_volatility(oracle, window=100.0)
        assert vol > 0

    def test_insufficient_data(self):
        """At least 3 observations are needed; fewer yield 0.0.

        Every sub-threshold count (0, 1 and 2 observations) is checked, not
        only the one just below the threshold, mirroring the empty-oracle
        checks done for TWAP, VWAP and deviation.
        """
        oracle = create_oracle()
        assert get_volatility(oracle) == 0.0

        oracle = record_observation(oracle, 100.0, 0.0)
        assert get_volatility(oracle) == 0.0

        oracle = record_observation(oracle, 105.0, 1.0)
        assert get_volatility(oracle) == 0.0

    def test_volatility_increases_with_swings(self):
        """Larger price swings produce higher volatility."""
        oracle_small = self._alternating(1)   # 101 / 99 zig-zag
        oracle_large = self._alternating(20)  # 120 / 80 zig-zag

        vol_small = get_volatility(oracle_small, window=100.0)
        vol_large = get_volatility(oracle_large, window=100.0)
        assert vol_large > vol_small


class TestRingBuffer:
    """Bounded observation buffer: eviction order and accumulator behaviour."""

    def test_max_observations_respected(self):
        """Once full, the buffer evicts the oldest observations first."""
        params = TWAPOracleParams(max_observations=5)
        oracle = create_oracle(params)

        for t in range(10):
            oracle = record_observation(oracle, float(t), float(t))

        assert len(oracle.observations) == 5
        # Ten writes into five slots: t=5..9 survive, t=0..4 were evicted.
        assert oracle.observations[0].time == 5.0
        assert oracle.observations[-1].time == 9.0

    def test_small_buffer(self):
        """A buffer of size 2 evicts correctly and can still serve a TWAP."""
        params = TWAPOracleParams(max_observations=2, min_observations=2)
        oracle = create_oracle(params)

        oracle = record_observation(oracle, 1.0, 0.0)
        oracle = record_observation(oracle, 2.0, 1.0)
        oracle = record_observation(oracle, 3.0, 2.0)

        assert len(oracle.observations) == 2
        assert oracle.observations[0].price == 2.0
        assert oracle.observations[1].price == 3.0
        # The two retained points satisfy min_observations=2, so the TWAP
        # must not fall back to the 0.0 "insufficient data" result.
        assert compute_twap(oracle, window=10.0) > 0.0

    def test_cumulative_survives_eviction(self):
        """The accumulator keeps every interval, including evicted ones."""
        params = TWAPOracleParams(max_observations=3)
        oracle = create_oracle(params)

        for t in range(6):
            oracle = record_observation(oracle, 10.0, float(t))

        # Six observations span five unit intervals at 10.0 -> 5 * 10.0 = 50.0,
        # even though only the last three observations are still buffered.
        assert oracle.cumulative_price_time == 50.0
        assert len(oracle.observations) == 3