Skip to content

Commit 7c94891

Browse files
committed
WIP/REFACT(net): part 3 of new dag-solver & pin refactoring
+ ALL TCs OK after ~10 fails with #25, #25 bugs, BUT... - WIP: STILL no PIN on PARALLEL. + move check if value in asked outputs before cache-evicting it in build-execution-plan time - compute methods don't need outputs anymore. + test: speed up parallel/multithread TCs by reducing delays & repetitions. + refact: network rightfully adopted stray functions for parallel processing - they all work on the net.graph, + upd: networkx api by indexing on `dag.nodes` views. + enh: add log message when deleting in parallel (on par with sequential code). + refact: var-renames, if-then-else simplifications, pythonisms.
1 parent 0dc1293 commit 7c94891

File tree

2 files changed

+102
-85
lines changed

2 files changed

+102
-85
lines changed

graphkit/network.py

Lines changed: 75 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def add_op(self, operation):
124124
assert operation.provides is not None, "Operation's 'provides' must be named"
125125

126126
# assert layer is only added once to graph
127-
assert operation not in self.graph.nodes(), "Operation may only be added once"
127+
assert operation not in self.graph.nodes, "Operation may only be added once"
128128

129129
## Invalidate old plans.
130130
self._cached_execution_plans = {}
@@ -152,7 +152,7 @@ def show_layers(self, debug=False, ret=False):
152152
else:
153153
print(s)
154154

155-
def _build_execution_plan(self, dag, inputs):
155+
def _build_execution_plan(self, dag, inputs, outputs):
156156
"""
157157
Create the list of operation-nodes & *instructions* evaluating all
158158
@@ -161,6 +161,8 @@ def _build_execution_plan(self, dag, inputs):
161161
162162
:param dag:
163163
The original dag, pruned; not broken.
164+
:param outputs:
165+
outp-names to decide whether to add (and which) del-instructions
164166
165167
In the list :class:`DeleteInstructions` steps (DA) are inserted between
166168
operation nodes to reduce the memory footprint of cached results.
@@ -187,22 +189,29 @@ def _build_execution_plan(self, dag, inputs):
187189
plan.append(PinInstruction(node))
188190

189191
elif isinstance(node, Operation):
190-
191192
plan.append(node)
192193

194+
# Keep all values in cache if not specific outputs asked.
195+
if not outputs:
196+
continue
197+
193198
# Add instructions to delete predecessors as possible. A
194199
# predecessor may be deleted if it is a data placeholder that
195200
# is no longer needed by future Operations.
196201
for need in self.graph.pred[node]:
197202
if self._debug:
198203
print("checking if node %s can be deleted" % need)
199204
for future_node in ordered_nodes[i+1:]:
200-
if isinstance(future_node, Operation) and need in future_node.needs:
205+
if (
206+
isinstance(future_node, Operation)
207+
and need in future_node.needs
208+
):
201209
break
202210
else:
203-
if self._debug:
204-
print(" adding delete instruction for %s" % need)
205-
plan.append(DeleteInstruction(need))
211+
if need not in outputs:
212+
if self._debug:
213+
print(" adding delete instruction for %s" % need)
214+
plan.append(DeleteInstruction(need))
206215

207216
else:
208217
raise AssertionError("Unrecognized network graph node %r" % node)
@@ -317,7 +326,7 @@ def _solve_dag(self, outputs, inputs):
317326
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
318327
pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
319328

320-
plan = self._build_execution_plan(pruned_dag, inputs)
329+
plan = self._build_execution_plan(pruned_dag, inputs, outputs)
321330

322331
return plan
323332

@@ -393,7 +402,7 @@ def compute(
393402
cache, overwrites_collector, named_inputs)
394403
else:
395404
self._compute_sequential_method(
396-
cache, overwrites_collector, named_inputs, outputs)
405+
cache, overwrites_collector, named_inputs)
397406

398407
if not outputs:
399408
# Return the whole cache as output, including input and
@@ -432,7 +441,7 @@ def _compute_thread_pool_barrier_method(
432441

433442

434443
# this keeps track of all nodes that have already executed
435-
has_executed = set() # unordered, not iterated
444+
executed_nodes = set() # unordered, not iterated
436445

437446
# with each loop iteration, we determine a set of operations that can be
438447
# scheduled, then schedule them onto a thread pool, then collect their
@@ -443,20 +452,22 @@ def _compute_thread_pool_barrier_method(
443452
# in the current round of scheduling
444453
upnext = []
445454
for node in self.execution_plan:
446-
# only delete if all successors for the data node have been executed
447455
if isinstance(node, DeleteInstruction):
448-
if ready_to_delete_data_node(node,
449-
has_executed,
450-
self.graph):
451-
if node in cache:
452-
cache.pop(node)
453-
454-
# continue if this node is anything but an operation node
455-
if not isinstance(node, Operation):
456-
continue
457-
458-
if ready_to_schedule_operation(node, has_executed, self.graph) \
459-
and node not in has_executed:
456+
# Only delete if all successors for the data node
457+
# have been executed
458+
# Cache value for an optional may be missing.
459+
if (
460+
node in cache
461+
and self._can_delete_data_node(node, executed_nodes)
462+
):
463+
if self._debug:
464+
print("removing data '%s' from cache." % node)
465+
del cache[node]
466+
elif (
467+
isinstance(node, Operation)
468+
and self._can_schedule_operation(node, executed_nodes)
469+
and node not in executed_nodes
470+
):
460471
upnext.append(node)
461472

462473

@@ -469,10 +480,10 @@ def _compute_thread_pool_barrier_method(
469480
upnext)
470481
for op, result in done_iterator:
471482
cache.update(result)
472-
has_executed.add(op)
483+
executed_nodes.add(op)
473484

474485

475-
def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
486+
def _compute_sequential_method(self, cache, overwrites, inputs):
476487
"""
477488
This method runs the graph one operation at a time in a single thread
478489
"""
@@ -500,18 +511,12 @@ def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
500511
if self._debug:
501512
print("step completion time: %s" % t_complete)
502513

503-
# Process DeleteInstructions by deleting the corresponding data
504-
# if possible.
505514
elif isinstance(step, DeleteInstruction):
506-
507-
if outputs and step not in outputs:
508-
# Some DeleteInstruction steps may not exist in the cache
509-
# if they come from optional() needs that are not privoded
510-
# as inputs. Make sure the step exists before deleting.
511-
if step in cache:
512-
if self._debug:
513-
print("removing data '%s' from cache." % step)
514-
cache.pop(step)
515+
# Cache value may be missing if it is optional.
516+
if step in cache:
517+
if self._debug:
518+
print("removing data '%s' from cache." % step)
519+
del cache[step]
515520

516521
elif isinstance(step, PinInstruction):
517522
self._pin_data_in_cache(step, cache, inputs, overwrites)
@@ -550,7 +555,7 @@ def get_node_name(a):
550555
g = pydot.Dot(graph_type="digraph")
551556

552557
# draw nodes
553-
for nx_node in self.graph.nodes():
558+
for nx_node in self.graph.nodes:
554559
if isinstance(nx_node, DataPlaceholderNode):
555560
node = pydot.Node(name=nx_node, shape="rect")
556561
else:
@@ -592,50 +597,45 @@ def get_node_name(a):
592597
return g
593598

594599

595-
def ready_to_schedule_operation(op, has_executed, graph):
596-
"""
597-
Determines if a Operation is ready to be scheduled for execution based on
598-
what has already been executed.
600+
def _can_schedule_operation(self, op, executed_nodes):
601+
"""
602+
Determines if an Operation is ready to be scheduled for execution
603+
604+
based on what has already been executed.
599605
600-
Args:
601-
op:
606+
:param op:
602607
The Operation object to check
603-
has_executed: set
608+
:param set executed_nodes:
604609
A set containing all operations that have been executed so far
605-
graph:
606-
The networkx graph containing the operations and data nodes
607-
Returns:
608-
A boolean indicating whether the operation may be scheduled for
609-
execution based on what has already been executed.
610-
"""
611-
# unordered, not iterated
612-
dependencies = set(filter(lambda v: isinstance(v, Operation),
613-
nx.ancestors(graph, op)))
614-
return dependencies.issubset(has_executed)
610+
:return:
611+
A boolean indicating whether the operation may be scheduled for
612+
execution based on what has already been executed.
613+
"""
614+
# unordered, not iterated
615+
dependencies = set(n for n in nx.ancestors(self.graph, op)
616+
if isinstance(n, Operation))
617+
return dependencies.issubset(executed_nodes)
615618

616-
def ready_to_delete_data_node(name, has_executed, graph):
617-
"""
618-
Determines if a DataPlaceholderNode is ready to be deleted from the
619-
cache.
620619

621-
Args:
622-
name:
620+
def _can_delete_data_node(self, name, executed_nodes):
621+
"""
622+
Determines if a DataPlaceholderNode is ready to be deleted from cache.
623+
624+
:param name:
623625
The name of the data node to check
624-
has_executed: set
626+
:param executed_nodes: set
625627
A set containing all operations that have been executed so far
626-
graph:
627-
The networkx graph containing the operations and data nodes
628-
Returns:
629-
A boolean indicating whether the data node can be deleted or not.
630-
"""
631-
data_node = get_data_node(name, graph)
632-
return set(graph.successors(data_node)).issubset(has_executed)
628+
:return:
629+
A boolean indicating whether the data node can be deleted or not.
630+
"""
631+
data_node = self.get_data_node(name)
632+
return data_node and set(
633+
self.graph.successors(data_node)).issubset(executed_nodes)
633634

634-
def get_data_node(name, graph):
635-
"""
636-
Gets a data node from a graph using its name
637-
"""
638-
for node in graph.nodes():
639-
if node == name and isinstance(node, DataPlaceholderNode):
635+
def get_data_node(self, name):
636+
"""
637+
Return the data node from a graph using its name, or None.
638+
"""
639+
node = self.graph.nodes[name]
640+
if isinstance(node, DataPlaceholderNode):
640641
return node
641-
return None

test/test_graphkit.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import graphkit.network as network
1212
import graphkit.modifiers as modifiers
1313
from graphkit import operation, compose, Operation
14+
from graphkit.network import DeleteInstruction
1415

1516

1617
def scream(*args, **kwargs):
@@ -226,7 +227,7 @@ def test_pruning_not_overrides_given_intermediate():
226227
netop.set_overwrites_collector(overwrites)
227228
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
228229
assert overwrites == {} # unjust must have been pruned
229-
230+
230231
overwrites = {}
231232
netop.set_overwrites_collector(overwrites)
232233
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
@@ -413,6 +414,9 @@ def addplusplus(a, b, c=0):
413414

414415
def test_deleteinstructs_vary_with_inputs():
415416
# Check #21: DeleteInstructions positions vary when inputs change.
417+
def count_deletions(steps):
418+
return sum(isinstance(n, DeleteInstruction) for n in steps)
419+
416420
netop = compose(name="netop")(
417421
operation(name="a free without b", needs=["a"], provides=["aa"])(identity),
418422
operation(name="satisfiable", needs=["a", "b"], provides=["ab"])(add),
@@ -438,27 +442,40 @@ def test_deleteinstructs_vary_with_inputs():
438442
assert res == filtdict(exp, "asked") # ok
439443
steps22 = netop.net.execution_plan
440444

441-
assert steps11 == steps12
442-
assert steps21 == steps22
443-
assert steps11 != steps21 # FAILs in v1.2.4 + #18
444-
assert steps12 != steps22 # FAILs in v1.2.4 + #18
445+
# When no outs, no del-instructs.
446+
assert steps11 != steps12
447+
assert count_deletions(steps11) == 0
448+
assert steps21 != steps22
449+
assert count_deletions(steps21) == 0
450+
451+
# Check steps vary with inputs
452+
#
453+
# FAILs in v1.2.4 + #18, PASS in #26
454+
assert steps11 != steps21
455+
456+
# Check deletes vary with inputs
457+
#
458+
# FAILs in v1.2.4 + #18, PASS in #26
459+
assert count_deletions(steps12) != count_deletions(steps22)
445460

446461

447462
def test_parallel_execution():
448463
import time
449464

465+
delay = 0.5
466+
450467
def fn(x):
451-
time.sleep(1)
468+
time.sleep(delay)
452469
print("fn %s" % (time.time() - t0))
453470
return 1 + x
454471

455472
def fn2(a,b):
456-
time.sleep(1)
473+
time.sleep(delay)
457474
print("fn2 %s" % (time.time() - t0))
458475
return a+b
459476

460477
def fn3(z, k=1):
461-
time.sleep(1)
478+
time.sleep(delay)
462479
print("fn3 %s" % (time.time() - t0))
463480
return z + k
464481

@@ -527,8 +544,8 @@ def infer(i):
527544
assert tuple(sorted(results.keys())) == tuple(sorted(outputs)), (outputs, results)
528545
return results
529546

530-
N = 100
531-
for i in range(20, 200):
547+
N = 33
548+
for i in range(13, 61):
532549
pool = Pool(i)
533550
pool.map(infer, range(N))
534551
pool.close()

0 commit comments

Comments
 (0)