101 lines
3.1 KiB
Python
101 lines
3.1 KiB
Python
"""Tests for dynamic weights."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from src.primitives.dynamic_weights import (
|
|
GradualChange, GradualWeightSchedule,
|
|
OracleMultiplier, OracleWeightSystem,
|
|
create_lbp_schedule, create_oracle_system,
|
|
simulate_lbp,
|
|
)
|
|
|
|
|
|
class TestGradualChange:
    """Exercise ``GradualChange.value_at`` around a single change window.

    Every test uses ``GradualChange(0.8, 0.5, 100.0, 200.0)``. Judging by the
    assertions, the arguments are (start value, end value, start time,
    end time) -- confirm against the class definition.
    """

    @staticmethod
    def _ramp():
        """Return the shared fixture object used by every test in this class."""
        return GradualChange(0.8, 0.5, 100.0, 200.0)

    def test_before_start(self):
        """A query at t=50 (earlier than 100) is expected to give 0.8."""
        observed = self._ramp().value_at(50.0)
        assert observed == 0.8

    def test_at_start(self):
        """A query exactly at t=100 is expected to give 0.8."""
        observed = self._ramp().value_at(100.0)
        assert observed == 0.8

    def test_midpoint(self):
        """A query at t=150 is expected to give 0.65 within 1e-10."""
        deviation = abs(self._ramp().value_at(150.0) - 0.65)
        assert deviation < 1e-10

    def test_at_end(self):
        """A query exactly at t=200 is expected to give 0.5."""
        observed = self._ramp().value_at(200.0)
        assert observed == 0.5

    def test_after_end(self):
        """A query at t=300 (later than 200) is expected to give 0.5."""
        observed = self._ramp().value_at(300.0)
        assert observed == 0.5
|
|
|
|
|
|
class TestWeightSchedule:
    """Checks on schedules produced by ``create_lbp_schedule``."""

    def test_lbp_schedule(self):
        """Check the first weight at the start, end, and middle of a 90/10 -> 50/50 schedule."""
        opening = np.array([0.9, 0.1])
        closing = np.array([0.5, 0.5])
        schedule = create_lbp_schedule(
            2, opening, closing, start_time=0, end_time=100
        )

        # (query time, expected first weight), evaluated in this order:
        # start 90/10, end 50/50, mid 70/30. Only component 0 is checked.
        checkpoints = ((0, 0.9), (100, 0.5), (50, 0.7))
        for moment, expected_first in checkpoints:
            weights = schedule.weights_at(moment)
            assert abs(weights[0] - expected_first) < 1e-10

    def test_weights_always_sum_to_one(self):
        """Check that a three-token schedule's weights total 1.0 at 20 sample times."""
        opening = np.array([0.6, 0.3, 0.1])
        closing = np.array([0.33, 0.34, 0.33])
        schedule = create_lbp_schedule(
            3, opening, closing, start_time=0, end_time=100
        )

        # 20 evenly spaced sample times from 0 to 100.
        sample_times = np.linspace(0, 100, 20)
        for moment in sample_times:
            total = sum(schedule.weights_at(moment))
            assert abs(total - 1.0) < 1e-10
|
|
|
|
|
|
class TestOracleWeights:
    """Checks on ``create_oracle_system`` results and on ``OracleMultiplier``."""

    def test_initial_weights(self):
        """Check that the first two weights at time 0 match the values passed in."""
        system = create_oracle_system(3, np.array([0.5, 0.3, 0.2]))
        current = system.weights_at(0)

        # Only the first two of the three components are checked.
        for index, expected in enumerate((0.5, 0.3)):
            assert abs(current[index] - expected) < 1e-10

    def test_multiplier_drift(self):
        """Check that opposite-signed multipliers move a 50/50 split apart by time 10."""
        # Token 0 gets a positive multiplier and token 1 a negative one, set at time 0.
        drifting = create_oracle_system(2, np.array([0.5, 0.5])).update_all(
            np.array([0.01, -0.01]), time=0
        )
        later = drifting.weights_at(10)

        # Token 0 should have gained weight and token 1 lost it.
        assert later[0] > 0.5
        assert later[1] < 0.5

    def test_clamped_to_bounds(self):
        """Check that a very fast-drifting multiplier reports 0.99 at time 100."""
        # Unclamped, the value at t=100 would be 100.5; the test expects 0.99.
        fast = OracleMultiplier(0.5, 1.0, 0)
        clamped = fast.weight_at(100)
        assert clamped == 0.99
|
|
|
|
|
|
class TestSimulate:
    """Check the output of ``simulate_lbp``."""

    def test_lbp_simulation(self):
        """Check output shapes and the first/last first-token weight over 50 steps."""
        opening = np.array([0.9, 0.1])
        closing = np.array([0.5, 0.5])
        schedule = create_lbp_schedule(
            2, opening, closing, start_time=0, end_time=100
        )

        result = simulate_lbp(schedule, 0, 100, n_steps=50)

        # One time stamp per step.
        times = result["times"]
        assert times.shape == (50,)

        # One row per step, one column per token.
        weights = result["weights"]
        assert weights.shape == (50, 2)

        # First step is roughly 90/10; last step is 50/50 to within 0.02.
        first_token_start = weights[0, 0]
        assert first_token_start > 0.85
        first_token_end = weights[-1, 0]
        assert abs(first_token_end - 0.5) < 0.02
|