@@ -137,12 +137,12 @@ def show_layers(self, debug=False, ret=False):
137137 def _build_execution_plan (self , dag ):
138138 """
139139 Create the list of operation-nodes & *instructions* evaluating all
140-
140+
141141 operations & instructions needed a) to free memory and b) avoid
142142 overwriting given intermediate inputs.
143143
144144 :param dag:
145- as shrinked by :meth:`compile()`
145+ the original dag but " shrinked", not "broken"
146146
147147 In the list :class:`DeleteInstructions` steps (DA) are inserted between
148148 operation nodes to reduce the memory footprint of cached results.
@@ -187,45 +187,57 @@ def _build_execution_plan(self, dag):
187187
188188 return plan
189189
190- def _collect_unsatisfiable_operations (self , necessary_nodes , inputs ):
190+ def _collect_unsatisfied_operations (self , dag , inputs ):
191191 """
192- Traverse ordered graph and mark satisfied needs on each operation,
192+ Traverse topologically sorted dag to collect un-satisfied operations.
193+
194+ Unsatisfied operations are those suffering from ANY of the following:
193195
194- collecting those missing at least one.
195- Since the graph is ordered, as soon as we're on an operation,
196- all its needs have been accounted, so we can get its satisfaction.
196+ - They are missing at least one compulsory need-input .
197+ Since the dag is ordered, as soon as we're on an operation,
198+ all its needs have been accounted, so we can get its satisfaction.
197199
198- :param necessary_nodes:
199- the subset of the graph to consider but WITHOUT the initial data
200- (because that is what :meth:`compile()` can gives us...)
200+ - Their provided outputs are not linked to any data in the dag.
201+ An operation might not have any output link when :meth:`_solve_dag()`
202+ has broken them, due to given intermediate inputs.
203+
204+ :param dag:
205+ the graph to consider
201206 :param inputs:
202207 an iterable of the names of the input values
203208 :return:
204- a list of unsatisfiable operations
209+ a list of unsatisfied operations to prune
205210 """
206- G = self . graph # shortcut
207- ok_data = set (inputs ) # to collect producible data
208- op_satisfaction = defaultdict ( set ) # to collect operation satisfiable needs
209- unsatisfiables = [] # to collect operations with partial needs
210- # We also need inputs to mark op_satisfaction .
211- nodes = chain ( necessary_nodes , inputs ) # note that `inputs` are plain strings
212- for node in nx .topological_sort (G . subgraph ( nodes ) ):
211+ # To collect data that will be produced.
212+ ok_data = set (inputs )
213+ # To collect the map of operations --> satisfied- needs.
214+ op_satisfaction = defaultdict ( set )
215+ # To collect the operations to drop .
216+ unsatisfied = []
217+ for node in nx .topological_sort (dag ):
213218 if isinstance (node , Operation ):
214- real_needs = set (n for n in node .needs if not isinstance (n , optional ))
215- if real_needs .issubset (op_satisfaction [node ]):
216- # mark all future data-provides as ok
217- ok_data .update (G .adj [node ])
219+ if not dag .adj [node ]:
220+ # Prune operations that ended up providing no output.
221+ unsatisfied .append (node )
218222 else :
219- unsatisfiables .append (node )
223+ real_needs = set (n for n in node .needs
224+ if not isinstance (n , optional ))
225+ if real_needs .issubset (op_satisfaction [node ]):
226+ # We have a satisfied operation; mark its output-data
227+ # as ok.
228+ ok_data .update (dag .adj [node ])
229+ else :
230+ # Prune operations with partial inputs.
231+ unsatisfied .append (node )
220232 elif isinstance (node , (DataPlaceholderNode , str )): # `str` are givens
221233 if node in ok_data :
222234 # mark satisfied-needs on all future operations
223- for future_op in G .adj [node ]:
235+ for future_op in dag .adj [node ]:
224236 op_satisfaction [future_op ].add (node )
225237 else :
226238 raise AssertionError ("Unrecognized network graph node %r" % node )
227239
228- return unsatisfiables
240+ return unsatisfied
229241
230242
231243 def _solve_dag (self , outputs , inputs ):
@@ -245,68 +257,64 @@ def _solve_dag(self, outputs, inputs):
245257 The inputs names of all given inputs.
246258
247259 :return:
248- the subgraph comprising the solution
249-
260+ the *execution plan*
250261 """
251- graph = self .graph
252- if not outputs :
262+ dag = self .graph
253263
254- # If caller requested all outputs, the necessary nodes are all
255- # nodes that are reachable from one of the inputs. Ignore input
256- # names that aren't in the graph.
257- necessary_nodes = set () # unordered, not iterated
258- for input_name in iter (inputs ):
259- if graph .has_node (input_name ):
260- necessary_nodes |= nx .descendants (graph , input_name )
264+ # Ignore input names that aren't in the graph.
265+ graph_inputs = iset (dag .nodes ) & inputs # preserve order
261266
262- else :
267+ # Scream if some requested outputs aren't in the graph.
268+ unknown_outputs = iset (outputs ) - dag .nodes
269+ if unknown_outputs :
270+ raise ValueError (
271+ "Unknown output node(s) requested: %s"
272+ % ", " .join (unknown_outputs ))
273+
274+ broken_dag = dag .copy () # preserve net's graph
263275
264- # If the caller requested a subset of outputs, find any nodes that
265- # are made unecessary because we were provided with an input that's
266- # deeper into the network graph. Ignore input names that aren't
267- # in the graph.
268- unnecessary_nodes = set () # unordered, not iterated
269- for input_name in iter (inputs ):
270- if graph .has_node (input_name ):
271- unnecessary_nodes |= nx .ancestors (graph , input_name )
272-
273- # Find the nodes we need to be able to compute the requested
274- # outputs. Raise an exception if a requested output doesn't
275- # exist in the graph.
276- necessary_nodes = set () # unordered, not iterated
277- for output_name in outputs :
278- if not graph .has_node (output_name ):
279- raise ValueError ("graphkit graph does not have an output "
280- "node named %s" % output_name )
281- necessary_nodes |= nx .ancestors (graph , output_name )
282-
283- # Get rid of the unnecessary nodes from the set of necessary ones.
284- necessary_nodes -= unnecessary_nodes
285-
286- # Drop (un-satifiable) operations with partial inputs.
276+ # Break the incoming edges to all given inputs.
277+ #
278+ # Nodes producing any given intermediate inputs are unnecessary
279+ # (unless they are also used elsewhere).
280+ # To discover which ones to prune, we break their incoming edges
281+ # and they will drop out while collecting ancestors from the outputs.
282+ for given in graph_inputs :
283+ broken_dag .remove_edges_from (list (broken_dag .in_edges (given )))
284+
285+ if outputs :
286+ # If caller requested specific outputs, we can prune any
287+ # unrelated nodes further up the dag.
288+ ending_in_outputs = set ()
289+ for input_name in outputs :
290+ ending_in_outputs .update (nx .ancestors (dag , input_name ))
291+ broken_dag = broken_dag .subgraph (ending_in_outputs | set (outputs ))
292+
293+
294+ # Prune (un-satisfiable) operations with partial inputs.
287295 # See yahoo/graphkit#18
288296 #
289- unsatisfiables = self ._collect_unsatisfiable_operations ( necessary_nodes , inputs )
290- necessary_nodes -= set ( unsatisfiables )
297+ unsatisfied = self ._collect_unsatisfied_operations ( broken_dag , inputs )
298+ shrinked_dag = dag . subgraph ( broken_dag . nodes - unsatisfied )
291299
292- shrinked_dag = graph . subgraph ( necessary_nodes )
300+ plan = self . _build_execution_plan ( shrinked_dag )
293301
294- return shrinked_dag
302+ return plan
295303
296304
297305 def compile (self , outputs = (), inputs = ()):
298306 """
299- Solve dag, set the :attr:`execution_plan` and cache it.
307+ Solve dag, set the :attr:`execution_plan`, and cache it.
300308
301- See :meth:`_solve_dag()` for description
309+ See :meth:`_solve_dag()` for detailed description.
302310
303311 :param iterable outputs:
304312 A list of desired output names. This can also be ``None``, in which
305313 case the necessary steps are all graph nodes that are reachable
306314 from one of the provided inputs.
307315
308316 :param dict inputs:
309- The inputs names of all given inputs.
317+ The input names of all given inputs.
310318 """
311319
312320 # return steps if it has already been computed before for this set of inputs and outputs
@@ -317,8 +325,7 @@ def compile(self, outputs=(), inputs=()):
317325 if cache_key in self ._cached_execution_plans :
318326 self .execution_plan = self ._cached_execution_plans [cache_key ]
319327 else :
320- dag = self ._solve_dag (outputs , inputs )
321- plan = self ._build_execution_plan (dag )
328+ plan = self ._solve_dag (outputs , inputs )
322329 # save this result in a precomputed cache for future lookup
323330 self .execution_plan = self ._cached_execution_plans [cache_key ] = plan
324331
@@ -338,16 +345,21 @@ def compute(self, outputs, named_inputs, method=None):
338345 and the values are the concrete values you
339346 want to set for the data node.
340347
348+ :param method:
349+ if ``"parallel"``, launches multi-threading.
350+ Set when invoking a composed graph or by
351+ :meth:`~NetworkOperation.set_execution_method()`.
341352
342353 :returns: a dictionary of output data objects, keyed by name.
343354 """
344355
345356 assert isinstance (outputs , (list , tuple )) or outputs is None ,\
346357 "The outputs argument must be a list"
347358
348- # start with fresh data cache
349- cache = {}
350- cache .update (named_inputs )
359+ # start with fresh data cache & overwrites
360+ cache = named_inputs .copy ()
361+
362+ # Build and set :attr:`execution_plan`.
351363 self .compile (outputs , named_inputs .keys ())
352364
353365 # choose a method of execution
0 commit comments