parent
a0160d7606
commit
27ed2c9031
|
|
@ -8,7 +8,7 @@ from numpy.random import RandomState
|
|||
from SimCAD.utils import key_filter
|
||||
from SimCAD.configuration.utils.behaviorAggregation import dict_elemwise_sum
|
||||
|
||||
|
||||
#Configuration(sim_config, state_dict, seed, exogenous_states, env_processes, mechanisms)
|
||||
class Configuration:
|
||||
def __init__(self,
|
||||
sim_config,
|
||||
|
|
@ -26,7 +26,6 @@ class Configuration:
|
|||
self.behavior_ops = behavior_ops
|
||||
self.mechanisms = mechanisms
|
||||
|
||||
|
||||
class Identity:
|
||||
def __init__(self, behavior_id={'indentity': 0}):
|
||||
self.beh_id_return_val = behavior_id
|
||||
|
|
@ -59,11 +58,11 @@ class Processor:
|
|||
self.state_identity = id.state_identity
|
||||
self.apply_identity_funcs = id.apply_identity_funcs
|
||||
|
||||
# Make returntype chosen by user.
|
||||
# Make returntype chosen by user. Must Classify Configs
|
||||
def create_matrix_field(self, mechanisms, key):
|
||||
if key == 'states':
|
||||
identity = self.state_identity
|
||||
elif key == 'behaviors':
|
||||
else:
|
||||
identity = self.behavior_identity
|
||||
df = pd.DataFrame(key_filter(mechanisms, key))
|
||||
col_list = self.apply_identity_funcs(identity, df, list(df.columns))
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ class TensorFieldReport:
|
|||
def __init__(self, config_proc):
|
||||
self.config_proc = config_proc
|
||||
|
||||
# ??? dont for-loop to apply exo_procs, use exo_proc struct
|
||||
# dont for-loop to apply exo_procs, use exo_proc struct
|
||||
def create_tensor_field(self, mechanisms, exo_proc, keys=['behaviors', 'states']):
|
||||
dfs = [self.config_proc.create_matrix_field(mechanisms, k) for k in keys]
|
||||
df = pd.concat(dfs, axis=1)
|
||||
|
|
|
|||
|
|
@ -69,14 +69,13 @@ class Executor:
|
|||
|
||||
# Dimensions: N x r x mechs
|
||||
|
||||
if self.exec_context == ExecutionMode.single_proc:
|
||||
tensor_field = create_tensor_field(mechanisms.pop(), eps.pop())
|
||||
result = self.exec_method(simulation_execs, states_lists, configs_structs, env_processes_list, Ts, Ns)
|
||||
return (result, tensor_field)
|
||||
elif self.exec_context == ExecutionMode.multi_proc:
|
||||
if self.exec_context == ExecutionMode.multi_proc:
|
||||
if len(self.configs) > 1:
|
||||
simulations = self.exec_method(simulation_execs, states_lists, configs_structs, env_processes_list, Ts, Ns)
|
||||
results = []
|
||||
for result, mechanism, ep in list(zip(simulations, mechanisms, eps)):
|
||||
results.append((flatten(result), create_tensor_field(mechanism, ep)))
|
||||
return results
|
||||
print(tabulate(create_tensor_field(mechanism, ep), headers='keys', tablefmt='psql'))
|
||||
results.append(flatten(result))
|
||||
return results
|
||||
else:
|
||||
return self.exec_method(simulation_execs, states_lists, configs_structs, env_processes_list, Ts, Ns)
|
||||
|
|
@ -20,13 +20,9 @@ print()
|
|||
first_config = [configs[0]] # from config1
|
||||
single_proc_ctx = ExecutionContext(context=exec_mode.single_proc)
|
||||
run1 = Executor(exec_context=single_proc_ctx, configs=first_config)
|
||||
run1_raw_result, tensor_field = run1.main()
|
||||
run1_raw_result = run1.main()
|
||||
result = pd.DataFrame(run1_raw_result)
|
||||
# result.to_csv('~/Projects/DiffyQ-SimCAD/results/config4.csv', sep=',')
|
||||
print()
|
||||
print("Tensor Field:")
|
||||
print(tabulate(tensor_field, headers='keys', tablefmt='psql'))
|
||||
print("Output:")
|
||||
print(tabulate(result, headers='keys', tablefmt='psql'))
|
||||
print()
|
||||
|
||||
|
|
@ -34,11 +30,8 @@ print("Simulation Execution 2: Pairwise Execution")
|
|||
print()
|
||||
multi_proc_ctx = ExecutionContext(context=exec_mode.multi_proc)
|
||||
run2 = Executor(exec_context=multi_proc_ctx, configs=configs)
|
||||
for raw_result, tensor_field in run2.main():
|
||||
run2_raw_results = run2.main()
|
||||
for raw_result in run2_raw_results:
|
||||
result = pd.DataFrame(raw_result)
|
||||
print()
|
||||
print("Tensor Field:")
|
||||
print(tabulate(tensor_field, headers='keys', tablefmt='psql'))
|
||||
print("Output:")
|
||||
print(tabulate(result, headers='keys', tablefmt='psql'))
|
||||
print()
|
||||
print()
|
||||
Loading…
Reference in New Issue