@@ -195,6 +195,9 @@ def _reset_node_attributes(self):
195195 self .dag .graph .nodes [node ]["noise_dist" ] = (
196196 self .scm_params .node_noise_dist_choices .sample_uniform ()
197197 )
198+ self .dag .graph .nodes [node ]["propagation_agg" ] = (
199+ self .scm_params .propagation_agg_choices .sample_uniform ()
200+ )
198201 self .dag .graph .nodes [node ]["decoder" ] = self .get_decoder (
199202 _stype = _stype , num_categories = num_categories
200203 )
@@ -213,6 +216,22 @@ def _reset_edge_attributes(self):
213216 _stype = parent_node_stype , num_categories = parent_node_num_categories
214217 )
215218
219+ def _aggregate_embeddings (
220+ self , embs : list [torch .Tensor ], weights : list [float ], mode : str
221+ ) -> torch .Tensor :
222+ weighted = [w * e for w , e in zip (weights , embs )]
223+ stack = torch .stack (weighted , dim = 0 ) # (n, emb_dim)
224+ if mode == "sum" :
225+ return stack .sum (dim = 0 )
226+ elif mode == "max" :
227+ return stack .max (dim = 0 ).values
228+ elif mode == "product" :
229+ return stack .prod (dim = 0 )
230+ elif mode == "logexp" :
231+ return torch .logsumexp (stack , dim = 0 )
232+ else :
233+ raise ValueError (f"Unknown aggregation mode: { mode } " )
234+
216235 def propagate (self , row_idx : int , foreign_row_idxs : list [int ], foreign_scms : list [SCM ]):
217236 foreign_scms_row_embds : list [list ] = []
218237 for foreign_row_idx , foreign_scm in zip (foreign_row_idxs , foreign_scms ):
@@ -234,26 +253,35 @@ def propagate(self, row_idx: int, foreign_row_idxs: list[int], foreign_scms: lis
234253 value = torch .Tensor ([value ])
235254 self .dag .graph .nodes [node ]["value" ] = value
236255 else :
237- parent_nodes = self .dag .graph .predecessors (node )
256+ parent_nodes = list ( self .dag .graph .predecessors (node ) )
238257 node_num_categories = self .dag .graph .nodes [node ]["num_categories" ]
258+ propagation_agg = self .dag .graph .nodes [node ]["propagation_agg" ]
239259
240260 # directly add noise
241261 noise_dist = self .dag .graph .nodes [node ]["noise_dist" ]
242262 node_emb = (
243263 noise_dist .sample (sample_shape = (self .scm_params .mlp_emb_dim ,)).squeeze ()
244264 / self .scm_params .mlp_emb_dim
245265 )
266+
267+ all_embs , all_weights = [], []
246268 for parent_node in parent_nodes :
247269 parent_attrs = self .dag .graph .nodes [parent_node ]
248270 encoder = self .dag .graph .edges [parent_node , node ]["encoder" ]
249- parent_emb = encoder (parent_attrs ["value" ]).squeeze ()
250- weight = self . dag . graph . get_edge_data ( parent_node , node )[ "weight" ]
251- node_emb += weight * parent_emb
252-
271+ all_embs . append ( encoder (parent_attrs ["value" ]).squeeze () )
272+ all_weights . append (
273+ self . dag . graph . get_edge_data ( parent_node , node )[ "weight" ]
274+ )
253275 for foreign_row_embds in foreign_scms_row_embds :
276+ w = 1 / len (foreign_row_embds ) if propagation_agg == "sum" else 1.0
254277 for foreign_row_embd in foreign_row_embds :
255- weight = 1 / len (foreign_row_embds )
256- node_emb += weight * foreign_row_embd
278+ all_embs .append (foreign_row_embd )
279+ all_weights .append (w )
280+
281+ if all_embs :
282+ node_emb = node_emb + self ._aggregate_embeddings (
283+ all_embs , all_weights , propagation_agg
284+ )
257285
258286 decoder = self .dag .graph .nodes [node ]["decoder" ]
259287 value = decoder (node_emb )
0 commit comments