Skip to content

Commit bd308eb

Browse files
committed
add propagation aggregate choices
1 parent be6f396 commit bd308eb

2 files changed

Lines changed: 40 additions & 7 deletions

File tree

plurel/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ class SCMParams:
176176
ts_ar_rho_choices: Choices = Choices(kind="range", value=[0.0, 0.9])
177177
ts_value_scale_choices: Choices = Choices(kind="set", value=[0.01, 0.1, 1, 10, 100])
178178

179+
propagation_agg_choices: Choices = Choices(
180+
kind="set",
181+
value=["sum", "max", "product", "logexp"],
182+
)
183+
179184
mlp_in_dim: int = 1
180185
mlp_out_dim: int = 1
181186
mlp_emb_dim: int = 32

plurel/scm.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def _reset_node_attributes(self):
195195
self.dag.graph.nodes[node]["noise_dist"] = (
196196
self.scm_params.node_noise_dist_choices.sample_uniform()
197197
)
198+
self.dag.graph.nodes[node]["propagation_agg"] = (
199+
self.scm_params.propagation_agg_choices.sample_uniform()
200+
)
198201
self.dag.graph.nodes[node]["decoder"] = self.get_decoder(
199202
_stype=_stype, num_categories=num_categories
200203
)
@@ -213,6 +216,22 @@ def _reset_edge_attributes(self):
213216
_stype=parent_node_stype, num_categories=parent_node_num_categories
214217
)
215218

219+
def _aggregate_embeddings(
220+
self, embs: list[torch.Tensor], weights: list[float], mode: str
221+
) -> torch.Tensor:
222+
weighted = [w * e for w, e in zip(weights, embs)]
223+
stack = torch.stack(weighted, dim=0) # (n, emb_dim)
224+
if mode == "sum":
225+
return stack.sum(dim=0)
226+
elif mode == "max":
227+
return stack.max(dim=0).values
228+
elif mode == "product":
229+
return stack.prod(dim=0)
230+
elif mode == "logexp":
231+
return torch.logsumexp(stack, dim=0)
232+
else:
233+
raise ValueError(f"Unknown aggregation mode: {mode}")
234+
216235
def propagate(self, row_idx: int, foreign_row_idxs: list[int], foreign_scms: list[SCM]):
217236
foreign_scms_row_embds: list[list] = []
218237
for foreign_row_idx, foreign_scm in zip(foreign_row_idxs, foreign_scms):
@@ -234,26 +253,35 @@ def propagate(self, row_idx: int, foreign_row_idxs: list[int], foreign_scms: lis
234253
value = torch.Tensor([value])
235254
self.dag.graph.nodes[node]["value"] = value
236255
else:
237-
parent_nodes = self.dag.graph.predecessors(node)
256+
parent_nodes = list(self.dag.graph.predecessors(node))
238257
node_num_categories = self.dag.graph.nodes[node]["num_categories"]
258+
propagation_agg = self.dag.graph.nodes[node]["propagation_agg"]
239259

240260
# directly add noise
241261
noise_dist = self.dag.graph.nodes[node]["noise_dist"]
242262
node_emb = (
243263
noise_dist.sample(sample_shape=(self.scm_params.mlp_emb_dim,)).squeeze()
244264
/ self.scm_params.mlp_emb_dim
245265
)
266+
267+
all_embs, all_weights = [], []
246268
for parent_node in parent_nodes:
247269
parent_attrs = self.dag.graph.nodes[parent_node]
248270
encoder = self.dag.graph.edges[parent_node, node]["encoder"]
249-
parent_emb = encoder(parent_attrs["value"]).squeeze()
250-
weight = self.dag.graph.get_edge_data(parent_node, node)["weight"]
251-
node_emb += weight * parent_emb
252-
271+
all_embs.append(encoder(parent_attrs["value"]).squeeze())
272+
all_weights.append(
273+
self.dag.graph.get_edge_data(parent_node, node)["weight"]
274+
)
253275
for foreign_row_embds in foreign_scms_row_embds:
276+
w = 1 / len(foreign_row_embds) if propagation_agg == "sum" else 1.0
254277
for foreign_row_embd in foreign_row_embds:
255-
weight = 1 / len(foreign_row_embds)
256-
node_emb += weight * foreign_row_embd
278+
all_embs.append(foreign_row_embd)
279+
all_weights.append(w)
280+
281+
if all_embs:
282+
node_emb = node_emb + self._aggregate_embeddings(
283+
all_embs, all_weights, propagation_agg
284+
)
257285

258286
decoder = self.dag.graph.nodes[node]["decoder"]
259287
value = decoder(node_emb)

0 commit comments

Comments
 (0)