"""Tests for dynamic weights.""" import numpy as np import pytest from src.primitives.dynamic_weights import ( GradualChange, GradualWeightSchedule, OracleMultiplier, OracleWeightSystem, create_lbp_schedule, create_oracle_system, simulate_lbp, ) class TestGradualChange: def test_before_start(self): gc = GradualChange(0.8, 0.5, 100.0, 200.0) assert gc.value_at(50.0) == 0.8 def test_at_start(self): gc = GradualChange(0.8, 0.5, 100.0, 200.0) assert gc.value_at(100.0) == 0.8 def test_midpoint(self): gc = GradualChange(0.8, 0.5, 100.0, 200.0) assert abs(gc.value_at(150.0) - 0.65) < 1e-10 def test_at_end(self): gc = GradualChange(0.8, 0.5, 100.0, 200.0) assert gc.value_at(200.0) == 0.5 def test_after_end(self): gc = GradualChange(0.8, 0.5, 100.0, 200.0) assert gc.value_at(300.0) == 0.5 class TestWeightSchedule: def test_lbp_schedule(self): schedule = create_lbp_schedule( 2, np.array([0.9, 0.1]), np.array([0.5, 0.5]), start_time=0, end_time=100, ) # Start: 90/10 w_start = schedule.weights_at(0) assert abs(w_start[0] - 0.9) < 1e-10 # End: 50/50 w_end = schedule.weights_at(100) assert abs(w_end[0] - 0.5) < 1e-10 # Mid: 70/30 w_mid = schedule.weights_at(50) assert abs(w_mid[0] - 0.7) < 1e-10 def test_weights_always_sum_to_one(self): schedule = create_lbp_schedule( 3, np.array([0.6, 0.3, 0.1]), np.array([0.33, 0.34, 0.33]), start_time=0, end_time=100, ) for t in np.linspace(0, 100, 20): w = schedule.weights_at(t) assert abs(sum(w) - 1.0) < 1e-10 class TestOracleWeights: def test_initial_weights(self): system = create_oracle_system(3, np.array([0.5, 0.3, 0.2])) w = system.weights_at(0) assert abs(w[0] - 0.5) < 1e-10 assert abs(w[1] - 0.3) < 1e-10 def test_multiplier_drift(self): system = create_oracle_system(2, np.array([0.5, 0.5])) # Set multiplier: token 0 increases, token 1 decreases system = system.update_all(np.array([0.01, -0.01]), time=0) w_later = system.weights_at(10) # Token 0 should now have higher weight assert w_later[0] > 0.5 assert w_later[1] < 0.5 def test_clamped_to_bounds(self): mult = OracleMultiplier(0.5, 1.0, 0) # Very fast drift # At t=100, raw would be 100.5 — should clamp to 0.99 assert mult.weight_at(100) == 0.99 class TestSimulate: def test_lbp_simulation(self): schedule = create_lbp_schedule( 2, np.array([0.9, 0.1]), np.array([0.5, 0.5]), start_time=0, end_time=100, ) result = simulate_lbp(schedule, 0, 100, n_steps=50) assert result["times"].shape == (50,) assert result["weights"].shape == (50, 2) # First step: ~90/10, last step: 50/50 assert result["weights"][0, 0] > 0.85 assert abs(result["weights"][-1, 0] - 0.5) < 0.02