type annotations: simulation.py
This commit is contained in:
parent
2fb0dcf754
commit
9d9e33b766
|
|
@ -129,4 +129,4 @@ for raw_result, tensor_field in run2.main():
|
|||
The above can be run in Jupyter.
|
||||
```bash
|
||||
jupyter notebook
|
||||
```
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class ExecutionContext:
|
|||
result = simulation(var_dict, states_list, config, env_processes, T, N)
|
||||
return flatten(result)
|
||||
|
||||
def parallelize_simulations(fs, var_dict_list, states_list, configs, env_processes, Ts, Ns):
|
||||
l = list(zip(fs, var_dict_list, states_list, configs, env_processes, Ts, Ns))
|
||||
def parallelize_simulations(simulations, var_dict_list, states_list, configs, env_processes, Ts, Ns):
|
||||
l = list(zip(simulations, var_dict_list, states_list, configs, env_processes, Ts, Ns))
|
||||
with Pool(len(configs)) as p:
|
||||
results = p.map(lambda t: t[0](t[1], t[2], t[3], t[4], t[5], t[6]), l)
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -2,19 +2,38 @@ from copy import deepcopy
|
|||
from fn.op import foldr, call
|
||||
|
||||
from cadCAD.engine.utils import engine_exception
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
id_exception = engine_exception(KeyError, KeyError, None)
|
||||
id_exception: Callable = engine_exception(KeyError, KeyError, None)
|
||||
|
||||
import pprint as pp
|
||||
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, policy_ops, policy_update_exception=id_exception, state_update_exception=id_exception):
|
||||
self.policy_ops = policy_ops # behavior_ops
|
||||
self.state_update_exception = state_update_exception
|
||||
self.policy_update_exception = policy_update_exception # behavior_update_exception
|
||||
def __init__(
|
||||
self,
|
||||
policy_ops: List[Callable],
|
||||
policy_update_exception: Callable = id_exception,
|
||||
state_update_exception: Callable = id_exception
|
||||
) -> None:
|
||||
|
||||
# behavior_ops
|
||||
self.policy_ops = policy_ops
|
||||
self.state_update_exception = state_update_exception
|
||||
self.policy_update_exception = policy_update_exception
|
||||
# behavior_update_exception
|
||||
|
||||
# get_behavior_input # sL: State Window
|
||||
def get_policy_input(
|
||||
self,
|
||||
var_dict: Dict[str, List[Any]],
|
||||
sub_step: int,
|
||||
sL: List[Dict[str, Any]],
|
||||
s: Dict[str, Any],
|
||||
funcs: List[Callable]
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# get_behavior_input
|
||||
def get_policy_input(self, var_dict, sub_step, sL, s, funcs):
|
||||
ops = self.policy_ops[::-1]
|
||||
|
||||
def get_col_results(var_dict, sub_step, sL, s, funcs):
|
||||
|
|
@ -22,23 +41,39 @@ class Executor:
|
|||
|
||||
return foldr(call, get_col_results(var_dict, sub_step, sL, s, funcs))(ops)
|
||||
|
||||
def apply_env_proc(self, env_processes, state_dict, sub_step):
|
||||
def apply_env_proc(
|
||||
self,
|
||||
env_processes: Dict[str, Callable],
|
||||
state_dict: Dict[str, Any],
|
||||
sub_step: int
|
||||
) -> None:
|
||||
for state in state_dict.keys():
|
||||
if state in list(env_processes.keys()):
|
||||
env_state = env_processes[state]
|
||||
env_state: Callable = env_processes[state]
|
||||
if (env_state.__name__ == '_curried') or (env_state.__name__ == 'proc_trigger'):
|
||||
state_dict[state] = env_state(sub_step)(state_dict[state])
|
||||
state_dict[state]: Any = env_state(sub_step)(state_dict[state])
|
||||
else:
|
||||
state_dict[state] = env_state(state_dict[state])
|
||||
state_dict[state]: Any = env_state(state_dict[state])
|
||||
|
||||
# mech_step
|
||||
def partial_state_update(self, var_dict, sub_step, sL, state_funcs, policy_funcs, env_processes, time_step, run):
|
||||
last_in_obj = sL[-1]
|
||||
def partial_state_update(
|
||||
self,
|
||||
var_dict: Dict[str, List[Any]],
|
||||
sub_step: int,
|
||||
sL: Any,
|
||||
state_funcs: List[Callable],
|
||||
policy_funcs: List[Callable],
|
||||
env_processes: Dict[str, Callable],
|
||||
time_step: int,
|
||||
run: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
_input = self.policy_update_exception(self.get_policy_input(var_dict, sub_step, sL, last_in_obj, policy_funcs))
|
||||
last_in_obj: Dict[str, Any] = sL[-1]
|
||||
|
||||
_input: Dict[str, Any] = self.policy_update_exception(self.get_policy_input(var_dict, sub_step, sL, last_in_obj, policy_funcs))
|
||||
|
||||
# ToDo: add env_proc generator to `last_in_copy` iterator as wrapper function
|
||||
last_in_copy = dict(
|
||||
last_in_copy: Dict[str, Any] = dict(
|
||||
[
|
||||
self.state_update_exception(f(var_dict, sub_step, sL, last_in_obj, _input)) for f in state_funcs
|
||||
]
|
||||
|
|
@ -46,58 +81,91 @@ class Executor:
|
|||
|
||||
for k in last_in_obj:
|
||||
if k not in last_in_copy:
|
||||
last_in_copy[k] = last_in_obj[k]
|
||||
last_in_copy[k]: Any = last_in_obj[k]
|
||||
|
||||
del last_in_obj
|
||||
|
||||
self.apply_env_proc(env_processes, last_in_copy, last_in_copy['timestep'])
|
||||
|
||||
last_in_copy['substep'], last_in_copy['timestep'], last_in_copy['run'] = sub_step, time_step, run
|
||||
|
||||
sL.append(last_in_copy)
|
||||
del last_in_copy
|
||||
|
||||
return sL
|
||||
|
||||
|
||||
# mech_pipeline
|
||||
def state_update_pipeline(self, var_dict, states_list, configs, env_processes, time_step, run):
|
||||
def state_update_pipeline(
|
||||
self,
|
||||
var_dict: Dict[str, List[Any]],
|
||||
states_list: List[Dict[str, Any]],
|
||||
configs: List[Tuple[List[Callable], List[Callable]]],
|
||||
env_processes: Dict[str, Callable],
|
||||
time_step: int,
|
||||
run: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
sub_step = 0
|
||||
states_list_copy = deepcopy(states_list)
|
||||
genesis_states = states_list_copy[-1]
|
||||
states_list_copy: List[Dict[str, Any]] = deepcopy(states_list)
|
||||
genesis_states: Dict[str, Any] = states_list_copy[-1]
|
||||
genesis_states['substep'], genesis_states['timestep'] = sub_step, time_step
|
||||
states_list = [genesis_states]
|
||||
states_list: List[Dict[str, Any]] = [genesis_states]
|
||||
|
||||
sub_step += 1
|
||||
for config in configs:
|
||||
s_conf, p_conf = config[0], config[1]
|
||||
states_list = self.partial_state_update(var_dict, sub_step, states_list, s_conf, p_conf, env_processes, time_step, run)
|
||||
states_list: List[Dict[str, Any]] = self.partial_state_update(
|
||||
var_dict, sub_step, states_list, s_conf, p_conf, env_processes, time_step, run
|
||||
)
|
||||
|
||||
sub_step += 1
|
||||
|
||||
time_step += 1
|
||||
|
||||
return states_list
|
||||
|
||||
def run_pipeline(self, var_dict, states_list, configs, env_processes, time_seq, run):
|
||||
time_seq = [x + 1 for x in time_seq]
|
||||
simulation_list = [states_list]
|
||||
def run_pipeline(
|
||||
self,
|
||||
var_dict: Dict[str, List[Any]],
|
||||
states_list: List[Dict[str, Any]],
|
||||
configs: List[Tuple[List[Callable], List[Callable]]],
|
||||
env_processes: Dict[str, Callable],
|
||||
time_seq: range,
|
||||
run: int
|
||||
) -> List[List[Dict[str, Any]]]:
|
||||
|
||||
time_seq: List[int] = [x + 1 for x in time_seq]
|
||||
simulation_list: List[List[Dict[str, Any]]] = [states_list]
|
||||
for time_step in time_seq:
|
||||
pipe_run = self.state_update_pipeline(var_dict, simulation_list[-1], configs, env_processes, time_step, run)
|
||||
pipe_run: List[Dict[str, Any]] = self.state_update_pipeline(
|
||||
var_dict, simulation_list[-1], configs, env_processes, time_step, run
|
||||
)
|
||||
_, *pipe_run = pipe_run
|
||||
simulation_list.append(pipe_run)
|
||||
|
||||
return simulation_list
|
||||
|
||||
# ToDo: Muiltithreaded Runs
|
||||
def simulation(self, var_dict, states_list, configs, env_processes, time_seq, runs):
|
||||
pipe_run = []
|
||||
def simulation(
|
||||
self,
|
||||
var_dict: Dict[str, List[Any]],
|
||||
states_list: List[Dict[str, Any]],
|
||||
configs: List[Tuple[List[Callable], List[Callable]]],
|
||||
env_processes: Dict[str, Callable],
|
||||
time_seq: range,
|
||||
runs: int
|
||||
) -> List[List[Dict[str, Any]]]:
|
||||
|
||||
pipe_run: List[List[Dict[str, Any]]] = []
|
||||
for run in range(runs):
|
||||
run += 1
|
||||
states_list_copy = deepcopy(states_list)
|
||||
states_list_copy: List[Dict[str, Any]] = deepcopy(states_list)
|
||||
head, *tail = self.run_pipeline(var_dict, states_list_copy, configs, env_processes, time_seq, run)
|
||||
genesis = head.pop()
|
||||
genesis['substep'], genesis['timestep'], genesis['run'] = 0, 0, run
|
||||
first_timestep_per_run = [genesis] + tail.pop(0)
|
||||
pipe_run += [first_timestep_per_run] + tail
|
||||
del states_list_copy
|
||||
|
||||
return pipe_run
|
||||
genesis: Dict[str, Any] = head.pop()
|
||||
genesis['substep'], genesis['timestep'], genesis['run'] = 0, 0, run
|
||||
first_timestep_per_run: List[Dict[str, Any]] = [genesis] + tail.pop(0)
|
||||
pipe_run += [first_timestep_per_run] + tail
|
||||
|
||||
return pipe_run
|
||||
|
|
|
|||
|
|
@ -39,4 +39,4 @@ def engine_exception(ErrorType, error_message, exception_function, try_function)
|
|||
def fit_param(param, x):
|
||||
return x + param
|
||||
|
||||
# fit_param = lambda param: lambda x: x + param
|
||||
# fit_param = lambda param: lambda x: x + param
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from collections import defaultdict
|
||||
from itertools import product
|
||||
import warnings
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def pipe(x):
|
||||
return x
|
||||
|
|
@ -41,11 +43,11 @@ def dict_filter(dictionary, condition):
|
|||
return dict([(k, v) for k, v in dictionary.items() if condition(v)])
|
||||
|
||||
|
||||
def get_max_dict_val_len(g):
|
||||
def get_max_dict_val_len(g: Dict[str, List[int]]) -> int:
|
||||
return len(max(g.values(), key=len))
|
||||
|
||||
|
||||
def tabulate_dict(d):
|
||||
def tabulate_dict(d: Dict[str, List[int]]) -> Dict[str, List[int]]:
|
||||
max_len = get_max_dict_val_len(d)
|
||||
_d = {}
|
||||
for k, vl in d.items():
|
||||
|
|
@ -57,7 +59,7 @@ def tabulate_dict(d):
|
|||
return _d
|
||||
|
||||
|
||||
def flatten_tabulated_dict(d):
|
||||
def flatten_tabulated_dict(d: Dict[str, List[int]]) -> List[Dict[str, int]]:
|
||||
max_len = get_max_dict_val_len(d)
|
||||
dl = [{} for i in range(max_len)]
|
||||
|
||||
|
|
@ -133,4 +135,4 @@ def curry_pot(f, *argv):
|
|||
# def decorator(f):
|
||||
# f.__name__ = newname
|
||||
# return f
|
||||
# return decorator
|
||||
# return decorator
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -21,6 +21,8 @@ def p1m1(_g, step, sL, s):
|
|||
def p2m1(_g, step, sL, s):
|
||||
return {'param2': 4}
|
||||
|
||||
# []
|
||||
|
||||
def p1m2(_g, step, sL, s):
|
||||
return {'param1': 'a', 'param2': 2}
|
||||
def p2m2(_g, step, sL, s):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from cadCAD.configuration import append_configs
|
|||
from cadCAD.configuration.utils import proc_trigger, ep_time_step
|
||||
from cadCAD.configuration.utils.parameterSweep import config_sim
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
|
||||
seeds = {
|
||||
|
|
@ -17,7 +19,7 @@ seeds = {
|
|||
}
|
||||
|
||||
|
||||
g = {
|
||||
g: Dict[str, List[int]] = {
|
||||
'alpha': [1],
|
||||
'beta': [2, 5],
|
||||
'gamma': [3, 4],
|
||||
|
|
|
|||
Loading…
Reference in New Issue