from functools import reduce
from fn.op import foldr
import pandas as pd

from SimCAD import configs
from SimCAD.utils import key_filter
from SimCAD.configuration.utils.policyAggregation import dict_elemwise_sum
from SimCAD.configuration.utils import exo_update_per_ts


class Configuration(object):
    """Plain container for a single simulation configuration.

    Every constructor argument is stored unchanged as an attribute of the same
    name; no validation is performed here.
    """

    def __init__(self, sim_config=None, initial_state=None, seeds=None, env_processes=None,
                 exogenous_states=None, partial_state_updates=None, policy_ops=None):
        """Store the configuration pieces as attributes.

        Args:
            policy_ops: list of policy-aggregation operations. When omitted,
                defaults to ``[foldr(dict_elemwise_sum())]``.
        """
        self.sim_config = sim_config
        self.initial_state = initial_state
        self.seeds = seeds
        self.env_processes = env_processes
        self.exogenous_states = exogenous_states
        self.partial_state_updates = partial_state_updates
        # Build the default per instance: a list literal in the signature would be
        # one object shared (and mutable) across every Configuration instance.
        if policy_ops is None:
            policy_ops = [foldr(dict_elemwise_sum())]
        self.policy_ops = policy_ops


def append_configs(sim_configs, initial_state, seeds, raw_exogenous_states, env_processes,
                   partial_state_updates, _exo_update_per_ts=True):
    """Create Configuration objects and append them to ``SimCAD.configs``.

    Args:
        sim_configs: a list of sim-config dicts (one Configuration is appended per
            element) or a single dict (one Configuration is appended).
        raw_exogenous_states: exogenous state definitions; passed through
            ``exo_update_per_ts`` first when ``_exo_update_per_ts is True``.
        _exo_update_per_ts: exactly ``True`` to apply ``exo_update_per_ts``;
            any other value uses ``raw_exogenous_states`` unchanged.

    All other arguments are forwarded unchanged to every Configuration created.
    """
    if _exo_update_per_ts is True:
        exogenous_states = exo_update_per_ts(raw_exogenous_states)
    else:
        exogenous_states = raw_exogenous_states

    if isinstance(sim_configs, list):
        sim_config_list = sim_configs
    elif isinstance(sim_configs, dict):
        sim_config_list = [sim_configs]
    else:
        # NOTE(review): other types are silently ignored (nothing appended),
        # matching the previous behaviour; consider raising TypeError.
        sim_config_list = []

    for sim_config in sim_config_list:
        configs.append(
            Configuration(
                sim_config=sim_config,
                initial_state=initial_state,
                seeds=seeds,
                exogenous_states=exogenous_states,
                env_processes=env_processes,
                partial_state_updates=partial_state_updates
            )
        )


class Identity:
    """Factory for the no-op policy / state-update functions used as matrix fillers."""

    def __init__(self, policy_id=None):
        """Args:
            policy_id: value returned by ``p_identity``; defaults to a fresh
                ``{'identity': 0}`` per instance.
        """
        # A dict literal default would be shared by every Identity instance, so
        # a caller mutating the returned value would change all of them.
        if policy_id is None:
            policy_id = {'identity': 0}
        self.beh_id_return_val = policy_id

    def p_identity(self, var_dict, sub_step, sL, s):
        """Identity policy: ignore all inputs and return the configured policy_id."""
        return self.beh_id_return_val

    def policy_identity(self, k):
        """Return the identity policy function (``k`` is ignored)."""
        return self.p_identity

    def no_state_identity(self, var_dict, sub_step, sL, s, _input):
        """State-update placeholder that returns None."""
        return None

    def state_identity(self, k):
        """Return a state-update function that re-emits the current value of ``s[k]``."""
        return lambda var_dict, sub_step, sL, s, _input: (k, s[k])

    def apply_identity_funcs(self, identity, df, cols):
        """Return one single-column DataFrame per name in ``cols``, with missing
        cells filled by ``identity(col)``."""
        return [df[[col]].fillna(value=identity(col)) for col in cols]


class Processor:
    """Turns partial-state-update definitions into per-sub-step function matrices."""

    def __init__(self, id=None):
        """Args:
            id: an Identity-like provider of filler functions; a new ``Identity()``
                is created when omitted. (Name kept for backward compatibility even
                though it shadows the builtin.)
        """
        identity_provider = Identity() if id is None else id
        self.id = identity_provider
        self.p_identity = identity_provider.p_identity
        self.policy_identity = identity_provider.policy_identity
        self.no_state_identity = identity_provider.no_state_identity
        self.state_identity = identity_provider.state_identity
        self.apply_identity_funcs = identity_provider.apply_identity_funcs

    def create_matrix_field(self, partial_state_updates, key):
        """Build a DataFrame of the ``key`` entries of ``partial_state_updates``.

        Missing cells are filled with the matching identity function. Returns
        ``pd.DataFrame({'empty': []})`` when there are no columns.

        Raises:
            ValueError: if ``key`` is neither ``'states'`` nor ``'policies'``.
        """
        if key == 'states':
            identity = self.state_identity
        elif key == 'policies':
            identity = self.policy_identity
        else:
            raise ValueError("key must be 'states' or 'policies', got {!r}".format(key))

        df = pd.DataFrame(key_filter(partial_state_updates, key))
        col_list = self.apply_identity_funcs(identity, df, list(df.columns))
        if col_list:
            return pd.concat(col_list, axis=1)
        return pd.DataFrame({'empty': []})

    def generate_config(self, initial_state, partial_state_updates, exo_proc):
        """Return a list of ``(state_functions, policy_functions)`` tuples, one per
        sub-step, with ``exo_proc`` appended to every sub-step's state functions.

        When ``partial_state_updates`` is empty, a single sub-step is produced whose
        state functions return the ``initial_state`` entries.
        """

        def no_update_handler(bdf, sdf):
            # Fill whichever of the policy (bdf) / state (sdf) matrices is empty
            # with identity functions so both have rows to zip.
            if not bdf.empty and sdf.empty:
                bdf_values = bdf.values.tolist()
                sdf_values = [
                    [self.no_state_identity] * len(bdf_values)
                    for _ in range(len(partial_state_updates))
                ]
            elif bdf.empty and not sdf.empty:
                sdf_values = sdf.values.tolist()
                # Was ``self.b_identity``, which Processor never defines (AttributeError).
                bdf_values = [
                    [self.p_identity] * len(sdf_values)
                    for _ in range(len(partial_state_updates))
                ]
            else:
                sdf_values = sdf.values.tolist()
                bdf_values = bdf.values.tolist()
            return sdf_values, bdf_values

        def only_ep_handler(state_dict):
            # One state function per entry, each returning its (key, value) pair.
            # k/v are bound as defaults: a plain closure would see only the last
            # pair. The signature matches state_identity / no_state_identity.
            sdf_functions = [
                lambda var_dict, sub_step, sL, s, _input, _k=k, _v=v: (_k, _v)
                for k, v in state_dict.items()
            ]
            sdf_values = [sdf_functions]
            bdf_values = [[self.p_identity] * len(sdf_values)]
            return sdf_values, bdf_values

        if len(partial_state_updates) != 0:
            bdf = self.create_matrix_field(partial_state_updates, 'policies')
            sdf = self.create_matrix_field(partial_state_updates, 'states')
            sdf_values, bdf_values = no_update_handler(bdf, sdf)
        else:
            sdf_values, bdf_values = only_ep_handler(initial_state)

        return [
            (state_funcs + exo_proc, policy_funcs)
            for state_funcs, policy_funcs in zip(sdf_values, bdf_values)
        ]