# myco-bonding-curve/tests/test_twap_oracle.py

"""Tests for TWAP oracle primitive."""
import math
import pytest
from src.primitives.twap_oracle import (
PriceObservation, TWAPOracleParams, TWAPOracleState,
create_oracle, record_observation,
compute_twap, compute_vwap,
spot_vs_twap_deviation, get_volatility,
)
class TestPriceObservation:
    """Recording observations into the oracle buffer and its accumulator."""

    def test_record_single(self):
        """A single recorded observation is stored with its price and time."""
        state = record_observation(create_oracle(), price=1.0, time=0.0)
        assert len(state.observations) == 1
        only = state.observations[0]
        assert only.price == 1.0
        assert only.time == 0.0

    def test_record_multiple(self):
        """Observations accumulate and the newest one is last."""
        state = create_oracle()
        for px, ts in ((1.0, 0.0), (1.5, 1.0), (2.0, 2.0)):
            state = record_observation(state, px, ts)
        assert len(state.observations) == 3
        assert state.observations[-1].price == 2.0

    def test_record_with_volume(self):
        """The volume keyword is carried onto the stored observation."""
        state = record_observation(create_oracle(), 1.0, 0.0, volume=100.0)
        assert state.observations[0].volume == 100.0

    def test_cumulative_accumulator_updates(self):
        """cumulative_price_time grows by previous price * elapsed time."""
        state = record_observation(create_oracle(), 10.0, 0.0)
        # No earlier observation, so there is no interval to accrue yet.
        assert state.cumulative_price_time == 0.0

        state = record_observation(state, 12.0, 5.0)
        # 10.0 held for 5.0 time units -> 50.0
        assert state.cumulative_price_time == 50.0

        state = record_observation(state, 8.0, 10.0)
        # 12.0 held for another 5.0 -> 50.0 + 60.0 = 110.0
        assert state.cumulative_price_time == 110.0
class TestTWAPComputation:
    """compute_twap over full, sliced, default-sized and minimal windows."""

    @staticmethod
    def _feed(state, prices, start=0):
        """Record prices[i] at time float(start + i), returning the new state."""
        for offset, px in enumerate(prices):
            state = record_observation(state, px, float(start + offset))
        return state

    def test_constant_price(self):
        """A flat price series averages to exactly that price."""
        state = self._feed(create_oracle(), [5.0] * 10)
        twap = compute_twap(state, window=100.0)
        assert abs(twap - 5.0) < 1e-10

    def test_rising_price(self):
        """A monotonically rising series averages strictly inside its range."""
        ramp = [1.0 + t * 0.1 for t in range(10)]
        twap = compute_twap(self._feed(create_oracle(), ramp), window=100.0)
        lowest, highest = 1.0, 1.9
        assert lowest < twap < highest

    def test_window_slicing(self):
        """Observations older than the window do not contribute."""
        state = self._feed(create_oracle(), [1.0] * 10)      # stale, low price
        state = self._feed(state, [10.0] * 10, start=10)     # recent, high price

        # A short look-back sees only the recent high-price regime.
        assert compute_twap(state, window=5.0) == pytest.approx(10.0, abs=0.1)

        # A long look-back blends both regimes.
        assert 1.0 < compute_twap(state, window=100.0) < 10.0

    def test_insufficient_data(self):
        """Too few observations yields 0.0 instead of a TWAP."""
        empty = create_oracle()
        assert compute_twap(empty) == 0.0

        # One point forms no interval; two observations are needed by default.
        single = record_observation(empty, 5.0, 0.0)
        assert compute_twap(single) == 0.0

    def test_default_window(self):
        """Omitting window falls back to params.default_window."""
        params = TWAPOracleParams(default_window=5.0)
        state = self._feed(create_oracle(params), [float(t) for t in range(20)])
        assert compute_twap(state) == compute_twap(state, window=5.0)

    def test_two_observations(self):
        """The smallest usable history: two points, one interval."""
        state = self._feed(create_oracle(), [2.0, 4.0])
        # Single interval [0, 1] during which the price was 2.0.
        assert compute_twap(state, window=10.0) == pytest.approx(2.0, abs=0.01)
class TestVWAP:
    """compute_vwap weighting, zero-volume and empty-state behaviour."""

    @staticmethod
    def _two_trades(vol_first, vol_second):
        """Oracle with 10.0 at t=0.0 and 20.0 at t=1.0, using the given volumes."""
        state = record_observation(create_oracle(), 10.0, 0.0, volume=vol_first)
        return record_observation(state, 20.0, 1.0, volume=vol_second)

    def test_equal_volume(self):
        """Equal volumes reduce VWAP to the plain mean of the prices."""
        vwap = compute_vwap(self._two_trades(100.0, 100.0), window=10.0)
        assert vwap == pytest.approx(15.0, abs=0.01)

    def test_volume_weighting(self):
        """Heavier volume at the lower price pulls VWAP toward it."""
        # (10 * 900 + 20 * 100) / (900 + 100) = 11.0
        vwap = compute_vwap(self._two_trades(900.0, 100.0), window=10.0)
        assert vwap == pytest.approx(11.0, abs=0.01)

    def test_zero_volume(self):
        """When nothing traded, VWAP is reported as 0.0."""
        assert compute_vwap(self._two_trades(0.0, 0.0), window=10.0) == 0.0

    def test_insufficient_data(self):
        """An oracle with no observations reports a VWAP of 0.0."""
        assert compute_vwap(create_oracle()) == 0.0
class TestDeviation:
    """Sign and size of spot_vs_twap_deviation against a flat price history."""

    @staticmethod
    def _flat_history():
        """Ten observations at price 10.0, one per unit of time from t=0."""
        state = create_oracle()
        for tick in range(10):
            state = record_observation(state, 10.0, float(tick))
        return state

    def test_spot_above_twap(self):
        """A spot of 12.0 against a TWAP of 10.0 gives about +0.2."""
        dev = spot_vs_twap_deviation(
            self._flat_history(), current_spot=12.0, window=100.0
        )
        assert dev > 0
        assert dev == pytest.approx(0.2, abs=0.01)

    def test_spot_below_twap(self):
        """A spot of 8.0 against a TWAP of 10.0 gives about -0.2."""
        dev = spot_vs_twap_deviation(
            self._flat_history(), current_spot=8.0, window=100.0
        )
        assert dev < 0
        assert dev == pytest.approx(-0.2, abs=0.01)

    def test_spot_equals_twap(self):
        """A spot equal to the TWAP gives (numerically) zero deviation."""
        dev = spot_vs_twap_deviation(
            self._flat_history(), current_spot=10.0, window=100.0
        )
        assert abs(dev) < 1e-10

    def test_no_data(self):
        """With no history there is no TWAP, so deviation is 0.0."""
        assert spot_vs_twap_deviation(create_oracle(), 10.0) == 0.0
class TestVolatility:
    """get_volatility on flat, noisy, too-short and zigzag series."""

    @staticmethod
    def _record_all(prices):
        """Record prices[i] at time float(i) into a fresh default oracle."""
        state = create_oracle()
        for tick, px in enumerate(prices):
            state = record_observation(state, px, float(tick))
        return state

    def test_constant_price_zero_vol(self):
        """A flat series has zero volatility."""
        vol = get_volatility(self._record_all([100.0] * 20), window=100.0)
        assert vol == pytest.approx(0.0, abs=1e-10)

    def test_varying_price_positive_vol(self):
        """A noisy series has strictly positive volatility."""
        raw = (100, 105, 98, 110, 95, 108, 102, 112, 97, 106)
        vol = get_volatility(self._record_all([float(p) for p in raw]), window=100.0)
        assert vol > 0

    def test_insufficient_data(self):
        """Two observations are not enough; at least three are required."""
        assert get_volatility(self._record_all([100.0, 105.0])) == 0.0

    def test_volatility_increases_with_swings(self):
        """A wider zigzag around 100.0 produces higher volatility."""

        def zigzag(amplitude):
            # Alternate 100 + amplitude (even t) and 100 - amplitude (odd t).
            return [100.0 + (amplitude if t % 2 == 0 else -amplitude) for t in range(20)]

        calm = self._record_all(zigzag(1))
        wild = self._record_all(zigzag(20))

        vol_calm = get_volatility(calm, window=100.0)
        vol_wild = get_volatility(wild, window=100.0)
        assert vol_wild > vol_calm
class TestRingBuffer:
    """Bounded observation storage and its interaction with the accumulator."""

    def test_max_observations_respected(self):
        """Once full, the oldest observations are dropped first."""
        state = create_oracle(TWAPOracleParams(max_observations=5))
        for tick in range(10):
            state = record_observation(state, float(tick), float(tick))

        kept = state.observations
        assert len(kept) == 5
        # Only ticks 5..9 survive, so t=5 is now the oldest entry.
        assert kept[0].time == 5.0
        assert kept[-1].time == 9.0

    def test_small_buffer(self):
        """A capacity of two still evicts correctly."""
        params = TWAPOracleParams(max_observations=2, min_observations=2)
        state = create_oracle(params)
        for px, ts in ((1.0, 0.0), (2.0, 1.0), (3.0, 2.0)):
            state = record_observation(state, px, ts)

        assert len(state.observations) == 2
        assert [obs.price for obs in state.observations] == [2.0, 3.0]

    def test_cumulative_survives_eviction(self):
        """The accumulator keeps every interval even after points are evicted."""
        state = create_oracle(TWAPOracleParams(max_observations=3))
        for tick in range(6):
            state = record_observation(state, 10.0, float(tick))

        # Five unit intervals at 10.0 -> 50.0, although only 3 points remain.
        assert state.cumulative_price_time == 50.0
        assert len(state.observations) == 3