@@ -137,12 +137,12 @@ def show_layers(self, debug=False, ret=False):
137137 def _build_execution_plan (self , dag ):
138138 """
139139 Create the list of operation-nodes & *instructions* evaluating all
140-
140+
141141 operations & instructions needed a) to free memory and b) avoid
142142 overwritting given intermediate inputs.
143143
144144 :param dag:
145- as shrinked by :meth:`compile()`
145+ the original dag but " shrinked", not "broken"
146146
147147 In the list :class:`DeleteInstructions` steps (DA) are inserted between
148148 operation nodes to reduce the memory footprint of cached results.
@@ -187,45 +187,52 @@ def _build_execution_plan(self, dag):
187187
188188 return plan
189189
190- def _collect_unsatisfiable_operations (self , necessary_nodes , inputs ):
190+ def _collect_unsatisfied_operations (self , dag , inputs ):
191191 """
192- Traverse ordered graph and mark satisfied needs on each operation,
192+ Traverse topologically sorted dag to collect un-satisfied operations.
193+
194+ Unsatisfied operations are those suffering from ANY of the following:
195+
196+ - They are missing at least one compulsory need-input.
197+ Since the dag is ordered, as soon as we're on an operation,
198+ all its needs have been accounted, so we can get its satisfaction.
193199
194- collecting those missing at least one .
195- Since the graph is ordered, as soon as we're on an operation,
196- all its needs have been accounted, so we can get its satisfaction .
200+ - Their provided outputs are not linked to any data in the dag .
201+ An operation might not have any output link when :meth:`_solve_dag()`
202+ has broken them, due to given intermediate inputs .
197203
198- :param necessary_nodes:
199- the subset of the graph to consider but WITHOUT the initial data
200- (because that is what :meth:`compile()` can gives us...)
204+ :param dag:
205+ the graph to consider
201206 :param inputs:
202207 an iterable of the names of the input values
203208 return:
204- a list of unsatisfiable operations
209+ a list of unsatisfied operations to prune
205210 """
206- G = self . graph # shortcut
207- ok_data = set (inputs ) # to collect producible data
208- op_satisfaction = defaultdict ( set ) # to collect operation satisfiable needs
209- unsatisfiables = [] # to collect operations with partial needs
210- # We also need inputs to mark op_satisfaction .
211- nodes = chain ( necessary_nodes , inputs ) # note that `inputs` are plain strings
212- for node in nx .topological_sort (G . subgraph ( nodes ) ):
211+ # To collect data that will be produced.
212+ ok_data = set (inputs )
213+ # To colect the map of operations --> satisfied- needs.
214+ op_satisfaction = defaultdict ( set )
215+ # To collect the operations to drop .
216+ unsatisfied = []
217+ for node in nx .topological_sort (dag ):
213218 if isinstance (node , Operation ):
214219 real_needs = set (n for n in node .needs if not isinstance (n , optional ))
215220 if real_needs .issubset (op_satisfaction [node ]):
216- # mark all future data-provides as ok
217- ok_data .update (G .adj [node ])
221+ # We have a satisfied operation; mark its output-data
222+ # as ok.
223+ ok_data .update (dag .adj [node ])
218224 else :
219- unsatisfiables .append (node )
225+ # Prune operations with partial inputs.
226+ unsatisfied .append (node )
220227 elif isinstance (node , (DataPlaceholderNode , str )): # `str` are givens
221228 if node in ok_data :
222229 # mark satisfied-needs on all future operations
223- for future_op in G .adj [node ]:
230+ for future_op in dag .adj [node ]:
224231 op_satisfaction [future_op ].add (node )
225232 else :
226233 raise AssertionError ("Unrecognized network graph node %r" % node )
227234
228- return unsatisfiables
235+ return unsatisfied
229236
230237
231238 def _solve_dag (self , outputs , inputs ):
@@ -245,68 +252,64 @@ def _solve_dag(self, outputs, inputs):
245252 The inputs names of all given inputs.
246253
247254 :return:
248- the subgraph comprising the solution
249-
255+ the *execution plan*
250256 """
251- graph = self .graph
252- if not outputs :
257+ dag = self .graph
253258
254- # If caller requested all outputs, the necessary nodes are all
255- # nodes that are reachable from one of the inputs. Ignore input
256- # names that aren't in the graph.
257- necessary_nodes = set () # unordered, not iterated
258- for input_name in iter (inputs ):
259- if graph .has_node (input_name ):
260- necessary_nodes |= nx .descendants (graph , input_name )
259+ # Ignore input names that aren't in the graph.
260+ graph_inputs = iset (dag .nodes ) & inputs # preserve order
261261
262- else :
262+ # Scream if some requested outputs aren't in the graph.
263+ unknown_outputs = iset (outputs ) - dag .nodes
264+ if unknown_outputs :
265+ raise ValueError (
266+ "Unknown output node(s) requested: %s"
267+ % ", " .join (unknown_outputs ))
268+
269+ broken_dag = dag .copy () # preserve net's graph
263270
264- # If the caller requested a subset of outputs, find any nodes that
265- # are made unecessary because we were provided with an input that's
266- # deeper into the network graph. Ignore input names that aren't
267- # in the graph.
268- unnecessary_nodes = set () # unordered, not iterated
269- for input_name in iter (inputs ):
270- if graph .has_node (input_name ):
271- unnecessary_nodes |= nx .ancestors (graph , input_name )
272-
273- # Find the nodes we need to be able to compute the requested
274- # outputs. Raise an exception if a requested output doesn't
275- # exist in the graph.
276- necessary_nodes = set () # unordered, not iterated
277- for output_name in outputs :
278- if not graph .has_node (output_name ):
279- raise ValueError ("graphkit graph does not have an output "
280- "node named %s" % output_name )
281- necessary_nodes |= nx .ancestors (graph , output_name )
282-
283- # Get rid of the unnecessary nodes from the set of necessary ones.
284- necessary_nodes -= unnecessary_nodes
285-
286- # Drop (un-satifiable) operations with partial inputs.
271+ # Break the incoming edges to all given inputs.
272+ #
273+ # Nodes producing any given intermediate inputs are unecessary
274+ # (unless they are also used elsewhere).
275+ # To discover which ones to prune, we break their incoming edges
276+ # and they will drop out while collecting ancestors from the outputs.
277+ for given in graph_inputs :
278+ broken_dag .remove_edges_from (list (broken_dag .in_edges (given )))
279+
280+ if outputs :
281+ # If caller requested specific outputs, we can prune any
282+ # unrelated nodes further up the dag.
283+ ending_in_outputs = set ()
284+ for input_name in outputs :
285+ ending_in_outputs .update (nx .ancestors (dag , input_name ))
286+ broken_dag = broken_dag .subgraph (ending_in_outputs | set (outputs ))
287+
288+
289+ # Prune (un-satifiable) operations with partial inputs.
287290 # See yahoo/graphkit#18
288291 #
289- unsatisfiables = self ._collect_unsatisfiable_operations ( necessary_nodes , inputs )
290- necessary_nodes -= set ( unsatisfiables )
292+ unsatisfied = self ._collect_unsatisfied_operations ( broken_dag , inputs )
293+ shrinked_dag = dag . subgraph ( broken_dag . nodes - unsatisfied )
291294
292- shrinked_dag = graph . subgraph ( necessary_nodes )
295+ plan = self . _build_execution_plan ( shrinked_dag )
293296
294- return shrinked_dag
297+ return plan
295298
296299
297300 def compile (self , outputs = (), inputs = ()):
298301 """
299- Solve dag, set the :attr:`execution_plan` and cache it.
302+ Solve dag, set the :attr:`execution_plan`, and cache it.
300303
301- See :meth:`_solve_dag()` for description
304+ See :meth:`_solve_dag()` for detailed description.
302305
303306 :param iterable outputs:
304307 A list of desired output names. This can also be ``None``, in which
305308 case the necessary steps are all graph nodes that are reachable
306309 from one of the provided inputs.
307310
308311 :param dict inputs:
309- The inputs names of all given inputs.
312+ The input names of all given inputs.
310313 """
311314
312315 # return steps if it has already been computed before for this set of inputs and outputs
@@ -317,8 +320,7 @@ def compile(self, outputs=(), inputs=()):
317320 if cache_key in self ._cached_execution_plans :
318321 self .execution_plan = self ._cached_execution_plans [cache_key ]
319322 else :
320- dag = self ._solve_dag (outputs , inputs )
321- plan = self ._build_execution_plan (dag )
323+ plan = self ._solve_dag (outputs , inputs )
322324 # save this result in a precomputed cache for future lookup
323325 self .execution_plan = self ._cached_execution_plans [cache_key ] = plan
324326
@@ -345,9 +347,10 @@ def compute(self, outputs, named_inputs, method=None):
345347 assert isinstance (outputs , (list , tuple )) or outputs is None ,\
346348 "The outputs argument must be a list"
347349
348- # start with fresh data cache
349- cache = {}
350- cache .update (named_inputs )
350+ # start with fresh data cache & overwrites
351+ cache = named_inputs .copy ()
352+
353+ # Build and set :attr:`execution_plan`.
351354 self .compile (outputs , named_inputs .keys ())
352355
353356 # choose a method of execution
0 commit comments