"""Tests for imbalance fees."""

import numpy as np
import pytest

from src.primitives.imbalance_fees import (
    compute_imbalance,
    surge_fee,
    compute_fee_adjusted_output,
    optimal_deposit_fee,
)

# Fee the tests expect `surge_fee` to return when no surge applies
# (balanced pool or imbalance-improving trade).
STATIC_SURGE_FEE = 0.003
# Fee the tests expect `optimal_deposit_fee` to assign to every asset
# when the pool is exactly balanced.
BASE_DEPOSIT_FEE = 0.001


class TestImbalance:
    """Tests for `compute_imbalance`.

    The expected values follow sum(|x_i - median(x)|) / sum(x), as worked
    through in `test_two_tokens`.
    """

    def test_balanced_is_zero(self):
        """Equal balances have no deviation from the median."""
        assert compute_imbalance(np.array([100.0, 100.0, 100.0])) == 0.0

    def test_imbalanced_is_positive(self):
        """Any spread in balances yields a strictly positive imbalance."""
        assert compute_imbalance(np.array([100.0, 200.0, 300.0])) > 0

    def test_extreme_imbalance_near_one(self):
        """One dominant balance pushes the imbalance close to 1."""
        imb = compute_imbalance(np.array([1.0, 1.0, 1000.0]))
        # median = 1, sum|dev| = 0 + 0 + 999 = 999, total = 1002 -> ~0.997
        assert 0.99 < imb <= 1.0

    def test_two_tokens(self):
        """Pin the exact formula on a two-asset pool."""
        imb = compute_imbalance(np.array([100.0, 200.0]))
        # median = 150, sum|dev| = 50+50 = 100, total = 300
        assert abs(imb - 100.0 / 300.0) < 1e-10


class TestSurgeFee:
    """Tests for `surge_fee`: static fee unless a trade worsens imbalance
    past the threshold."""

    def test_balanced_gets_static_fee(self):
        """A small imbalance below the threshold keeps the static fee."""
        before = np.array([100.0, 100.0])
        after = np.array([110.0, 90.0])  # Small imbalance
        fee = surge_fee(before, after, threshold=0.5)
        # approx: avoid brittle exact float equality
        assert fee == pytest.approx(STATIC_SURGE_FEE)

    def test_worsening_imbalance_gets_surge(self):
        """A trade that pushes imbalance past the threshold is surcharged."""
        before = np.array([100.0, 100.0, 100.0])
        after = np.array([200.0, 50.0, 50.0])  # Heavy imbalance
        fee = surge_fee(before, after, threshold=0.1)
        assert fee > STATIC_SURGE_FEE  # Should be surged

    def test_improving_imbalance_no_surge(self):
        """A trade that reduces imbalance is never surcharged."""
        before = np.array([200.0, 50.0, 50.0])
        after = np.array([150.0, 75.0, 75.0])  # Improving
        fee = surge_fee(before, after, threshold=0.1)
        assert fee == pytest.approx(STATIC_SURGE_FEE)  # Static (improving)


class TestOptimalFees:
    """Tests for `optimal_deposit_fee` per-asset deposit pricing."""

    def test_underweight_gets_discount(self):
        """Depositing the underweight asset is cheaper than the others."""
        fees = optimal_deposit_fee(
            np.array([500.0, 1000.0, 1000.0]),
            np.array([100.0, 100.0, 100.0]),
        )
        assert fees[0] < fees[1]  # Underweight asset cheaper
        assert fees[0] < fees[2]

    def test_balanced_gets_base_fee(self):
        """A balanced pool charges the base fee on every asset."""
        fees = optimal_deposit_fee(
            np.array([1000.0, 1000.0]),
            np.array([100.0, 100.0]),
        )
        np.testing.assert_allclose(
            fees, [BASE_DEPOSIT_FEE, BASE_DEPOSIT_FEE], atol=1e-10
        )