"""Tests for the StableSwap primitives.

Covers the invariant ``D`` (``compute_invariant``), swap quoting in both
directions (``calc_out_given_in`` / ``calc_in_given_out``) and
``spot_price``.
"""

import numpy as np
import pytest

from src.primitives.stableswap import (
    compute_invariant,
    calc_out_given_in,
    calc_in_given_out,
    spot_price,
)


class TestInvariant:
    """Properties of the invariant D and how it depends on the amplification A."""

    def test_balanced_pool_D_equals_sum(self):
        """For balanced pools with high A, D ≈ sum of balances."""
        balances = np.array([1000.0, 1000.0])
        D = compute_invariant(balances, amp=1000)
        assert abs(D - 2000.0) < 1.0  # Very close to sum

    def test_balanced_three_tokens(self):
        """Three-token balanced pool."""
        balances = np.array([1000.0, 1000.0, 1000.0])
        D = compute_invariant(balances, amp=1000)
        assert abs(D - 3000.0) < 1.0

    def test_high_amp_approaches_constant_sum(self):
        """As A increases, D approaches sum(balances) even for imbalanced pools."""
        balances = np.array([900.0, 1100.0])
        D_low = compute_invariant(balances, amp=10)
        D_high = compute_invariant(balances, amp=10000)
        assert abs(D_high - 2000.0) < abs(D_low - 2000.0)

    def test_low_amp_more_slippage(self):
        """With lower A, swaps show more slippage (more constant-product-like)."""
        balances = np.array([1000.0, 1000.0])
        # Low A should give a worse exchange rate than high A.
        out_low_A = calc_out_given_in(
            balances, amp=1, token_in_index=0, token_out_index=1, amount_in=100.0
        )
        out_high_A = calc_out_given_in(
            balances, amp=5000, token_in_index=0, token_out_index=1, amount_in=100.0
        )
        # High A should give more output (less slippage).
        assert out_high_A > out_low_A

    def test_homogeneity_degree_1(self):
        """D(k*b) = k * D(b)."""
        balances = np.array([500.0, 1500.0])
        amp = 200
        D_base = compute_invariant(balances, amp)
        for k in [0.5, 2.0, 5.0]:
            D_scaled = compute_invariant(k * balances, amp)
            # Relative tolerance, so the check is independent of pool scale.
            assert abs(D_scaled - k * D_base) < 1e-6 * k * D_base


class TestSwaps:
    """Swap quoting: invariant preservation, inversion, slippage, spot price."""

    def test_invariant_preserved(self):
        """Swap preserves D."""
        balances = np.array([1000.0, 1000.0])
        amp = 500
        amount_in = 100.0
        D_before = compute_invariant(balances, amp)
        amount_out = calc_out_given_in(balances, amp, 0, 1, amount_in)
        # Apply the swap to a copy so the pre-swap balances stay untouched
        # (same pattern as test_three_token_swap).
        new_balances = balances.copy()
        new_balances[0] += amount_in
        new_balances[1] -= amount_out
        D_after = compute_invariant(new_balances, amp)
        assert abs(D_after - D_before) < 1e-6

    def test_near_1_to_1_with_high_amp(self):
        """High-A balanced pool should swap nearly 1:1."""
        balances = np.array([1000.0, 1000.0])
        amount_out = calc_out_given_in(
            balances, amp=5000, token_in_index=0, token_out_index=1, amount_in=10.0
        )
        # Should be very close to 10.0 for high A.
        assert abs(amount_out - 10.0) < 0.1

    def test_round_trip(self):
        """calc_in_given_out inverts calc_out_given_in."""
        balances = np.array([1000.0, 1000.0])
        amp = 200
        amount_in = 50.0
        amount_out = calc_out_given_in(balances, amp, 0, 1, amount_in)
        recovered = calc_in_given_out(balances, amp, 0, 1, amount_out)
        assert abs(recovered - amount_in) < 1e-6

    def test_large_swap_high_slippage(self):
        """Large swap on an imbalanced pool trades at a premium but slips.

        Token 0 is scarce (100 vs 1900), so selling it into the pool buys
        more than 1:1 of token 1. The trade still suffers slippage: the
        StableSwap curve is convex, so output per unit of input falls as the
        trade grows. A large trade's average rate must therefore be strictly
        worse than a small trade's.
        """
        balances = np.array([100.0, 1900.0])  # Very imbalanced
        amp = 100
        amount_in = 50.0
        amount_out = calc_out_given_in(balances, amp, 0, 1, amount_in)
        # Token 0 is scarce, so it trades at a premium: more than 1:1 out.
        assert amount_out > amount_in

        # Slippage: a small trade approximates the marginal rate; the large
        # trade's average rate (out per in) must be worse.
        small_in = 1.0
        small_out = calc_out_given_in(balances, amp, 0, 1, small_in)
        assert amount_out / amount_in < small_out / small_in

    def test_spot_price_balanced(self):
        """Spot price should be ~1.0 for balanced pool."""
        balances = np.array([1000.0, 1000.0])
        sp = spot_price(balances, amp=500, token_in_index=0, token_out_index=1)
        assert abs(sp - 1.0) < 0.01

    def test_three_token_swap(self):
        """Three-token pool swap preserves D."""
        balances = np.array([1000.0, 1000.0, 1000.0])
        amp = 500
        D_before = compute_invariant(balances, amp)
        amount_out = calc_out_given_in(balances, amp, 0, 2, 100.0)
        new_balances = balances.copy()
        new_balances[0] += 100.0
        new_balances[2] -= amount_out
        D_after = compute_invariant(new_balances, amp)
        assert abs(D_after - D_before) < 1e-4