@@ -124,7 +124,7 @@ def add_op(self, operation):
124124 assert operation .provides is not None , "Operation's 'provides' must be named"
125125
126126 # assert layer is only added once to graph
127- assert operation not in self .graph .nodes () , "Operation may only be added once"
127+ assert operation not in self .graph .nodes , "Operation may only be added once"
128128
129129 ## Invalidate old plans.
130130 self ._cached_execution_plans = {}
@@ -152,7 +152,7 @@ def show_layers(self, debug=False, ret=False):
152152 else :
153153 print (s )
154154
155- def _build_execution_plan (self , dag , inputs ):
155+ def _build_execution_plan (self , dag , inputs , outputs ):
156156 """
157157 Create the list of operation-nodes & *instructions* evaluating all
158158
@@ -161,6 +161,8 @@ def _build_execution_plan(self, dag, inputs):
161161
162162 :param dag:
163163 The original dag, pruned; not broken.
164+ :param outputs:
165+ output names, used to decide whether (and which) del-instructions to add
164166
165167 In the list :class:`DeleteInstructions` steps (DA) are inserted between
166168 operation nodes to reduce the memory footprint of cached results.
@@ -187,22 +189,29 @@ def _build_execution_plan(self, dag, inputs):
187189 plan .append (PinInstruction (node ))
188190
189191 elif isinstance (node , Operation ):
190-
191192 plan .append (node )
192193
194+ # Keep all values in cache if not specific outputs asked.
195+ if not outputs :
196+ continue
197+
193198 # Add instructions to delete predecessors as possible. A
194199 # predecessor may be deleted if it is a data placeholder that
195200 # is no longer needed by future Operations.
196201 for need in self .graph .pred [node ]:
197202 if self ._debug :
198203 print ("checking if node %s can be deleted" % need )
199204 for future_node in ordered_nodes [i + 1 :]:
200- if isinstance (future_node , Operation ) and need in future_node .needs :
205+ if (
206+ isinstance (future_node , Operation )
207+ and need in future_node .needs
208+ ):
201209 break
202210 else :
203- if self ._debug :
204- print (" adding delete instruction for %s" % need )
205- plan .append (DeleteInstruction (need ))
211+ if need not in outputs :
212+ if self ._debug :
213+ print (" adding delete instruction for %s" % need )
214+ plan .append (DeleteInstruction (need ))
206215
207216 else :
208217 raise AssertionError ("Unrecognized network graph node %r" % node )
@@ -317,7 +326,7 @@ def _solve_dag(self, outputs, inputs):
317326 unsatisfied = self ._collect_unsatisfied_operations (broken_dag , inputs )
318327 pruned_dag = dag .subgraph (broken_dag .nodes - unsatisfied )
319328
320- plan = self ._build_execution_plan (pruned_dag , inputs )
329+ plan = self ._build_execution_plan (pruned_dag , inputs , outputs )
321330
322331 return plan
323332
@@ -393,7 +402,7 @@ def compute(
393402 cache , overwrites_collector , named_inputs )
394403 else :
395404 self ._compute_sequential_method (
396- cache , overwrites_collector , named_inputs , outputs )
405+ cache , overwrites_collector , named_inputs )
397406
398407 if not outputs :
399408 # Return the whole cache as output, including input and
@@ -432,7 +441,7 @@ def _compute_thread_pool_barrier_method(
432441
433442
434443 # this keeps track of all nodes that have already executed
435- has_executed = set () # unordered, not iterated
444+ executed_nodes = set () # unordered, not iterated
436445
437446 # with each loop iteration, we determine a set of operations that can be
438447 # scheduled, then schedule them onto a thread pool, then collect their
@@ -443,20 +452,22 @@ def _compute_thread_pool_barrier_method(
443452 # in the current round of scheduling
444453 upnext = []
445454 for node in self .execution_plan :
446- # only delete if all successors for the data node have been executed
447455 if isinstance (node , DeleteInstruction ):
448- if ready_to_delete_data_node (node ,
449- has_executed ,
450- self .graph ):
451- if node in cache :
452- cache .pop (node )
453-
454- # continue if this node is anything but an operation node
455- if not isinstance (node , Operation ):
456- continue
457-
458- if ready_to_schedule_operation (node , has_executed , self .graph ) \
459- and node not in has_executed :
456+ # Only delete if all successors for the data node
457+ # have been executed
458+ # Cache value for an optional may be missing.
459+ if (
460+ node in cache
461+ and self ._can_delete_data_node (node , executed_nodes )
462+ ):
463+ if self ._debug :
464+ print ("removing data '%s' from cache." % node )
465+ del cache [node ]
466+ elif (
467+ isinstance (node , Operation )
468+ and self ._can_schedule_operation (node , executed_nodes )
469+ and node not in executed_nodes
470+ ):
460471 upnext .append (node )
461472
462473
@@ -469,10 +480,10 @@ def _compute_thread_pool_barrier_method(
469480 upnext )
470481 for op , result in done_iterator :
471482 cache .update (result )
472- has_executed .add (op )
483+ executed_nodes .add (op )
473484
474485
475- def _compute_sequential_method (self , cache , overwrites , inputs , outputs ):
486+ def _compute_sequential_method (self , cache , overwrites , inputs ):
476487 """
477488 This method runs the graph one operation at a time in a single thread
478489 """
@@ -500,18 +511,12 @@ def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
500511 if self ._debug :
501512 print ("step completion time: %s" % t_complete )
502513
503- # Process DeleteInstructions by deleting the corresponding data
504- # if possible.
505514 elif isinstance (step , DeleteInstruction ):
506-
507- if outputs and step not in outputs :
508- # Some DeleteInstruction steps may not exist in the cache
509- # if they come from optional() needs that are not privoded
510- # as inputs. Make sure the step exists before deleting.
511- if step in cache :
512- if self ._debug :
513- print ("removing data '%s' from cache." % step )
514- cache .pop (step )
515+ # Cache value may be missing if it is optional.
516+ if step in cache :
517+ if self ._debug :
518+ print ("removing data '%s' from cache." % step )
519+ del cache [step ]
515520
516521 elif isinstance (step , PinInstruction ):
517522 self ._pin_data_in_cache (step , cache , inputs , overwrites )
@@ -550,7 +555,7 @@ def get_node_name(a):
550555 g = pydot .Dot (graph_type = "digraph" )
551556
552557 # draw nodes
553- for nx_node in self .graph .nodes () :
558+ for nx_node in self .graph .nodes :
554559 if isinstance (nx_node , DataPlaceholderNode ):
555560 node = pydot .Node (name = nx_node , shape = "rect" )
556561 else :
@@ -592,50 +597,45 @@ def get_node_name(a):
592597 return g
593598
594599
595- def ready_to_schedule_operation (op , has_executed , graph ):
596- """
597- Determines if a Operation is ready to be scheduled for execution based on
598- what has already been executed.
600+ def _can_schedule_operation (self , op , executed_nodes ):
601+ """
602+ Determines if an Operation is ready to be scheduled for execution
603+
604+ based on what has already been executed.
599605
600- Args:
601- op:
606+ :param op:
602607 The Operation object to check
603- has_executed: set
608+ :param set executed_nodes:
604609 A set containing all operations that have been executed so far
605- graph:
606- The networkx graph containing the operations and data nodes
607- Returns:
608- A boolean indicating whether the operation may be scheduled for
609- execution based on what has already been executed.
610- """
611- # unordered, not iterated
612- dependencies = set (filter (lambda v : isinstance (v , Operation ),
613- nx .ancestors (graph , op )))
614- return dependencies .issubset (has_executed )
610+ :return:
611+ A boolean indicating whether the operation may be scheduled for
612+ execution based on what has already been executed.
613+ """
614+ # unordered, not iterated
615+ dependencies = set (n for n in nx .ancestors (self .graph , op )
616+ if isinstance (n , Operation ))
617+ return dependencies .issubset (executed_nodes )
615618
616- def ready_to_delete_data_node (name , has_executed , graph ):
617- """
618- Determines if a DataPlaceholderNode is ready to be deleted from the
619- cache.
620619
621- Args:
622- name:
620+ def _can_delete_data_node (self , name , executed_nodes ):
621+ """
622+ Determines if a DataPlaceholderNode is ready to be deleted from cache.
623+
624+ :param name:
623625 The name of the data node to check
624- has_executed : set
626+ :param executed_nodes : set
625627 A set containing all operations that have been executed so far
626- graph:
627- The networkx graph containing the operations and data nodes
628- Returns:
629- A boolean indicating whether the data node can be deleted or not.
630- """
631- data_node = get_data_node (name , graph )
632- return set (graph .successors (data_node )).issubset (has_executed )
628+ :return:
629+ A boolean indicating whether the data node can be deleted or not.
630+ """
631+ data_node = self .get_data_node (name )
632+ return data_node and set (
633+ self .graph .successors (data_node )).issubset (executed_nodes )
633634
634- def get_data_node (name , graph ):
635- """
636- Gets a data node from a graph using its name
637- """
638- for node in graph .nodes ():
639- if node == name and isinstance (node , DataPlaceholderNode ):
635+ def get_data_node (self , name ):
636+ """
637+ Return the data node from a graph using its name, or None.
638+ """
639+ node = self . graph .nodes [ name ]
640+ if isinstance (node , DataPlaceholderNode ):
640641 return node
641- return None
0 commit comments