# myco-bonding-curve/tests/test_dynamic_weights.py

"""Tests for src.primitives.dynamic_weights: gradual value changes, LBP weight schedules, oracle-driven weights and LBP simulation."""
import numpy as np
import pytest
from src.primitives.dynamic_weights import (
GradualChange, GradualWeightSchedule,
OracleMultiplier, OracleWeightSystem,
create_lbp_schedule, create_oracle_system,
simulate_lbp,
)
class TestGradualChange:
    """Value of a GradualChange before, at the edges of, inside and after its window."""

    @staticmethod
    def _change():
        """Build the fixture used by every test: 0.8 -> 0.5 over t in [100, 200]."""
        return GradualChange(0.8, 0.5, 100.0, 200.0)

    def test_before_start(self):
        # Before the window opens, the start value is reported unchanged.
        assert self._change().value_at(50.0) == 0.8

    def test_at_start(self):
        change = self._change()
        assert change.value_at(100.0) == 0.8

    def test_midpoint(self):
        # Halfway through the window: 0.8 + (0.5 - 0.8) / 2 = 0.65.
        halfway = self._change().value_at(150.0)
        assert halfway == pytest.approx(0.65, abs=1e-10)

    def test_at_end(self):
        change = self._change()
        assert change.value_at(200.0) == 0.5

    def test_after_end(self):
        # Once the window has closed, the end value is held.
        assert self._change().value_at(300.0) == 0.5
class TestWeightSchedule:
    """Weight schedules produced by create_lbp_schedule."""

    def test_lbp_schedule(self):
        schedule = create_lbp_schedule(
            2,
            np.array([0.9, 0.1]),
            np.array([0.5, 0.5]),
            start_time=0,
            end_time=100,
        )
        # (time, expected weight of token 0): 90/10 at start, 50/50 at end, 70/30 halfway.
        checkpoints = ((0, 0.9), (100, 0.5), (50, 0.7))
        for when, expected_first in checkpoints:
            first_weight = schedule.weights_at(when)[0]
            assert first_weight == pytest.approx(expected_first, abs=1e-10)

    def test_weights_always_sum_to_one(self):
        schedule = create_lbp_schedule(
            3,
            np.array([0.6, 0.3, 0.1]),
            np.array([0.33, 0.34, 0.33]),
            start_time=0,
            end_time=100,
        )
        # Sample 20 evenly spaced instants over the whole schedule, endpoints included.
        for when in np.linspace(0, 100, 20):
            total = sum(schedule.weights_at(when))
            assert total == pytest.approx(1.0, abs=1e-10)
class TestOracleWeights:
    """Oracle-driven weights: initial values, drift under multipliers, and clamping."""

    def test_initial_weights(self):
        weights = create_oracle_system(3, np.array([0.5, 0.3, 0.2])).weights_at(0)
        # Only the first two tokens are checked at t=0.
        for idx, expected in ((0, 0.5), (1, 0.3)):
            assert weights[idx] == pytest.approx(expected, abs=1e-10)

    def test_multiplier_drift(self):
        base = create_oracle_system(2, np.array([0.5, 0.5]))
        # Positive multiplier on token 0, negative on token 1, applied from t=0.
        drifted = base.update_all(np.array([0.01, -0.01]), time=0)
        weights = drifted.weights_at(10)
        # Token 0 has gained weight at the expense of token 1.
        assert weights[0] > 0.5
        assert weights[1] < 0.5

    def test_clamped_to_bounds(self):
        # A multiplier of 1.0 drifts very fast: by t=100 the unclamped weight would be
        # 100.5, so the reported weight must sit at the 0.99 upper bound instead.
        fast = OracleMultiplier(0.5, 1.0, 0)
        assert fast.weight_at(100) == 0.99
class TestSimulate:
    """Output of simulate_lbp over a two-token 90/10 -> 50/50 schedule."""

    def test_lbp_simulation(self):
        n_steps = 50
        schedule = create_lbp_schedule(
            2,
            np.array([0.9, 0.1]),
            np.array([0.5, 0.5]),
            start_time=0,
            end_time=100,
        )
        result = simulate_lbp(schedule, 0, 100, n_steps=n_steps)
        times, weights = result["times"], result["weights"]
        # One sample time per step; one row per step and one column per token.
        assert times.shape == (n_steps,)
        assert weights.shape == (n_steps, 2)
        # First step is close to 90/10; the last step lands on 50/50 within tolerance.
        first_token = weights[:, 0]
        assert first_token[0] > 0.85
        assert abs(first_token[-1] - 0.5) < 0.02