Skip to content

Commit 0830b7c

Browse files
committed
ENH(DAG): NEW SOLVER
+ Pruning behaves correctly also when outputs are given; this happens by breaking incoming provide-links to any given intermediate inputs. + Unsatisfied detection now includes those operations without outputs due to broken links (above). + Remove some unneeded "glue" from unsatisfied-detection code, leftover from previous compile() refactoring. + Renamed satisfiable --> satisfied. + Improved the error message raised when an unknown output is requested. + x3 TCs PASS, x1 in #24 and the first x2 in #25. - 1x TCs in #25 still FAIL, and need "Pinning" of given-inputs (the operation MUST and MUST NOT run in these cases).
1 parent 17eb2fd commit 0830b7c

File tree

2 files changed

+87
-74
lines changed

2 files changed

+87
-74
lines changed

graphkit/network.py

Lines changed: 84 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ def show_layers(self, debug=False, ret=False):
137137
def _build_execution_plan(self, dag):
138138
"""
139139
Create the list of operation-nodes & *instructions* evaluating all
140-
140+
141141
operations & instructions needed a) to free memory and b) avoid
142142
overwriting given intermediate inputs.
143143
144144
:param dag:
145-
as shrinked by :meth:`compile()`
145+
the original dag but "shrinked", not "broken"
146146
147147
In the list :class:`DeleteInstructions` steps (DA) are inserted between
148148
operation nodes to reduce the memory footprint of cached results.
@@ -187,45 +187,57 @@ def _build_execution_plan(self, dag):
187187

188188
return plan
189189

190-
def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
190+
def _collect_unsatisfied_operations(self, dag, inputs):
191191
"""
192-
Traverse ordered graph and mark satisfied needs on each operation,
192+
Traverse topologically sorted dag to collect un-satisfied operations.
193+
194+
Unsatisfied operations are those suffering from ANY of the following:
193195
194-
collecting those missing at least one.
195-
Since the graph is ordered, as soon as we're on an operation,
196-
all its needs have been accounted, so we can get its satisfaction.
196+
- They are missing at least one compulsory need-input.
197+
Since the dag is ordered, as soon as we're on an operation,
198+
all its needs have been accounted, so we can get its satisfaction.
197199
198-
:param necessary_nodes:
199-
the subset of the graph to consider but WITHOUT the initial data
200-
(because that is what :meth:`compile()` can gives us...)
200+
- Their provided outputs are not linked to any data in the dag.
201+
An operation might not have any output link when :meth:`_solve_dag()`
202+
has broken them, due to given intermediate inputs.
203+
204+
:param dag:
205+
the graph to consider
201206
:param inputs:
202207
an iterable of the names of the input values
203208
:return:
204-
a list of unsatisfiable operations
209+
a list of unsatisfied operations to prune
205210
"""
206-
G = self.graph # shortcut
207-
ok_data = set(inputs) # to collect producible data
208-
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
209-
unsatisfiables = [] # to collect operations with partial needs
210-
# We also need inputs to mark op_satisfaction.
211-
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
212-
for node in nx.topological_sort(G.subgraph(nodes)):
211+
# To collect data that will be produced.
212+
ok_data = set(inputs)
213+
# To collect the map of operations --> satisfied-needs.
214+
op_satisfaction = defaultdict(set)
215+
# To collect the operations to drop.
216+
unsatisfied = []
217+
for node in nx.topological_sort(dag):
213218
if isinstance(node, Operation):
214-
real_needs = set(n for n in node.needs if not isinstance(n, optional))
215-
if real_needs.issubset(op_satisfaction[node]):
216-
# mark all future data-provides as ok
217-
ok_data.update(G.adj[node])
219+
if not dag.adj[node]:
220+
# Prune operations that ended up providing no output.
221+
unsatisfied.append(node)
218222
else:
219-
unsatisfiables.append(node)
223+
real_needs = set(n for n in node.needs
224+
if not isinstance(n, optional))
225+
if real_needs.issubset(op_satisfaction[node]):
226+
# We have a satisfied operation; mark its output-data
227+
# as ok.
228+
ok_data.update(dag.adj[node])
229+
else:
230+
# Prune operations with partial inputs.
231+
unsatisfied.append(node)
220232
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
221233
if node in ok_data:
222234
# mark satisfied-needs on all future operations
223-
for future_op in G.adj[node]:
235+
for future_op in dag.adj[node]:
224236
op_satisfaction[future_op].add(node)
225237
else:
226238
raise AssertionError("Unrecognized network graph node %r" % node)
227239

228-
return unsatisfiables
240+
return unsatisfied
229241

230242

231243
def _solve_dag(self, outputs, inputs):
@@ -245,68 +257,64 @@ def _solve_dag(self, outputs, inputs):
245257
The inputs names of all given inputs.
246258
247259
:return:
248-
the subgraph comprising the solution
249-
260+
the *execution plan*
250261
"""
251-
graph = self.graph
252-
if not outputs:
262+
dag = self.graph
253263

254-
# If caller requested all outputs, the necessary nodes are all
255-
# nodes that are reachable from one of the inputs. Ignore input
256-
# names that aren't in the graph.
257-
necessary_nodes = set() # unordered, not iterated
258-
for input_name in iter(inputs):
259-
if graph.has_node(input_name):
260-
necessary_nodes |= nx.descendants(graph, input_name)
264+
# Ignore input names that aren't in the graph.
265+
graph_inputs = iset(dag.nodes) & inputs # preserve order
261266

262-
else:
267+
# Scream if some requested outputs aren't in the graph.
268+
unknown_outputs = iset(outputs) - dag.nodes
269+
if unknown_outputs:
270+
raise ValueError(
271+
"Unknown output node(s) requested: %s"
272+
% ", ".join(unknown_outputs))
273+
274+
broken_dag = dag.copy() # preserve net's graph
263275

264-
# If the caller requested a subset of outputs, find any nodes that
265-
# are made unecessary because we were provided with an input that's
266-
# deeper into the network graph. Ignore input names that aren't
267-
# in the graph.
268-
unnecessary_nodes = set() # unordered, not iterated
269-
for input_name in iter(inputs):
270-
if graph.has_node(input_name):
271-
unnecessary_nodes |= nx.ancestors(graph, input_name)
272-
273-
# Find the nodes we need to be able to compute the requested
274-
# outputs. Raise an exception if a requested output doesn't
275-
# exist in the graph.
276-
necessary_nodes = set() # unordered, not iterated
277-
for output_name in outputs:
278-
if not graph.has_node(output_name):
279-
raise ValueError("graphkit graph does not have an output "
280-
"node named %s" % output_name)
281-
necessary_nodes |= nx.ancestors(graph, output_name)
282-
283-
# Get rid of the unnecessary nodes from the set of necessary ones.
284-
necessary_nodes -= unnecessary_nodes
285-
286-
# Drop (un-satifiable) operations with partial inputs.
276+
# Break the incoming edges to all given inputs.
277+
#
278+
# Nodes producing any given intermediate inputs are unnecessary
279+
# (unless they are also used elsewhere).
280+
# To discover which ones to prune, we break their incoming edges
281+
# and they will drop out while collecting ancestors from the outputs.
282+
for given in graph_inputs:
283+
broken_dag.remove_edges_from(list(broken_dag.in_edges(given)))
284+
285+
if outputs:
286+
# If caller requested specific outputs, we can prune any
287+
# unrelated nodes further up the dag.
288+
ending_in_outputs = set()
289+
for input_name in outputs:
290+
ending_in_outputs.update(nx.ancestors(dag, input_name))
291+
broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))
292+
293+
294+
# Prune (un-satisfiable) operations with partial inputs.
287295
# See yahoo/graphkit#18
288296
#
289-
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
290-
necessary_nodes -= set(unsatisfiables)
297+
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
298+
shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
291299

292-
shrinked_dag = graph.subgraph(necessary_nodes)
300+
plan = self._build_execution_plan(shrinked_dag)
293301

294-
return shrinked_dag
302+
return plan
295303

296304

297305
def compile(self, outputs=(), inputs=()):
298306
"""
299-
Solve dag, set the :attr:`execution_plan` and cache it.
307+
Solve dag, set the :attr:`execution_plan`, and cache it.
300308
301-
See :meth:`_solve_dag()` for description
309+
See :meth:`_solve_dag()` for detailed description.
302310
303311
:param iterable outputs:
304312
A list of desired output names. This can also be ``None``, in which
305313
case the necessary steps are all graph nodes that are reachable
306314
from one of the provided inputs.
307315
308316
:param dict inputs:
309-
The inputs names of all given inputs.
317+
The input names of all given inputs.
310318
"""
311319

312320
# return steps if it has already been computed before for this set of inputs and outputs
@@ -317,8 +325,7 @@ def compile(self, outputs=(), inputs=()):
317325
if cache_key in self._cached_execution_plans:
318326
self.execution_plan = self._cached_execution_plans[cache_key]
319327
else:
320-
dag = self._solve_dag(outputs, inputs)
321-
plan = self._build_execution_plan(dag)
328+
plan = self._solve_dag(outputs, inputs)
322329
# save this result in a precomputed cache for future lookup
323330
self.execution_plan = self._cached_execution_plans[cache_key] = plan
324331

@@ -338,16 +345,21 @@ def compute(self, outputs, named_inputs, method=None):
338345
and the values are the concrete values you
339346
want to set for the data node.
340347
348+
:param method:
349+
if ``"parallel"``, launches multi-threading.
350+
Set when invoking a composed graph or by
351+
:meth:`~NetworkOperation.set_execution_method()`.
341352
342353
:returns: a dictionary of output data objects, keyed by name.
343354
"""
344355

345356
assert isinstance(outputs, (list, tuple)) or outputs is None,\
346357
"The outputs argument must be a list"
347358

348-
# start with fresh data cache
349-
cache = {}
350-
cache.update(named_inputs)
359+
# start with fresh data cache & overwrites
360+
cache = named_inputs.copy()
361+
362+
# Build and set :attr:`execution_plan`.
351363
self.compile(outputs, named_inputs.keys())
352364

353365
# choose a method of execution

test/test_graphkit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_pruning_multiouts_not_override_intermediates2():
245245
operation(name="op2", needs=["d", "e"], provides=["asked"])(mul),
246246
)
247247

248-
exp = {"a": 5, "overriden": 1, "c": 2, "asked": 3}
248+
exp = {"a": 5, "overriden": 1, "c": 2, "d": 3, "e": 10, "asked": 30}
249249
# FAILs
250250
# - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13)
251251
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
@@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
265265
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
266266
)
267267

268-
exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
268+
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
269269
# v1.2.4 is ok
270270
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
271271
# FAILS
272272
# - on v1.2.4 with KeyError: 'a',
273273
# - on #18 (unsatisfied) with no result.
274+
# FIXED on #18+#26 (new dag solver).
274275
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
275276

276277

0 commit comments

Comments
 (0)