Skip to content

Commit 31aae2f

Browse files
committed
ENH(DAG): NEW SOLVER
+ Pruning behaves correctly also when outputs given; this happens by breaking incoming provide-links to any given intermedediate inputs. + Unsatisfied detection now includes those without outputs due to broken links (above). + Remove some uneeded "glue" from unsatisfied-detection code, leftover from previous compile() refactoring. + Renamed satisfiable --> satisfied. + Improved unknown output requested raise-message. + x3 TCs PASS, x1 in #24 and the first x2 in #25. - 1x TCs in #25 still FAIL, and need "Pinning" of given-inputs (the operation MUST and MUST NOT run in these cases).
1 parent 17eb2fd commit 31aae2f

File tree

2 files changed

+75
-71
lines changed

2 files changed

+75
-71
lines changed

graphkit/network.py

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ def show_layers(self, debug=False, ret=False):
137137
def _build_execution_plan(self, dag):
138138
"""
139139
Create the list of operation-nodes & *instructions* evaluating all
140-
140+
141141
operations & instructions needed a) to free memory and b) avoid
142142
overwritting given intermediate inputs.
143143
144144
:param dag:
145-
as shrinked by :meth:`compile()`
145+
the original dag but "shrinked", not "broken"
146146
147147
In the list :class:`DeleteInstructions` steps (DA) are inserted between
148148
operation nodes to reduce the memory footprint of cached results.
@@ -187,45 +187,52 @@ def _build_execution_plan(self, dag):
187187

188188
return plan
189189

190-
def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
190+
def _collect_unsatisfied_operations(self, dag, inputs):
191191
"""
192-
Traverse ordered graph and mark satisfied needs on each operation,
192+
Traverse topologically sorted dag to collect un-satisfied operations.
193+
194+
Unsatisfied operations are those suffering from ANY of the following:
195+
196+
- They are missing at least one compulsory need-input.
197+
Since the dag is ordered, as soon as we're on an operation,
198+
all its needs have been accounted, so we can get its satisfaction.
193199
194-
collecting those missing at least one.
195-
Since the graph is ordered, as soon as we're on an operation,
196-
all its needs have been accounted, so we can get its satisfaction.
200+
- Their provided outputs are not linked to any data in the dag.
201+
An operation might not have any output link when :meth:`_solve_dag()`
202+
has broken them, due to given intermediate inputs.
197203
198-
:param necessary_nodes:
199-
the subset of the graph to consider but WITHOUT the initial data
200-
(because that is what :meth:`compile()` can gives us...)
204+
:param dag:
205+
the graph to consider
201206
:param inputs:
202207
an iterable of the names of the input values
203208
return:
204-
a list of unsatisfiable operations
209+
a list of unsatisfied operations to prune
205210
"""
206-
G = self.graph # shortcut
207-
ok_data = set(inputs) # to collect producible data
208-
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
209-
unsatisfiables = [] # to collect operations with partial needs
210-
# We also need inputs to mark op_satisfaction.
211-
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
212-
for node in nx.topological_sort(G.subgraph(nodes)):
211+
# To collect data that will be produced.
212+
ok_data = set(inputs)
213+
# To colect the map of operations --> satisfied-needs.
214+
op_satisfaction = defaultdict(set)
215+
# To collect the operations to drop.
216+
unsatisfied = []
217+
for node in nx.topological_sort(dag):
213218
if isinstance(node, Operation):
214219
real_needs = set(n for n in node.needs if not isinstance(n, optional))
215220
if real_needs.issubset(op_satisfaction[node]):
216-
# mark all future data-provides as ok
217-
ok_data.update(G.adj[node])
221+
# We have a satisfied operation; mark its output-data
222+
# as ok.
223+
ok_data.update(dag.adj[node])
218224
else:
219-
unsatisfiables.append(node)
225+
# Prune operations with partial inputs.
226+
unsatisfied.append(node)
220227
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
221228
if node in ok_data:
222229
# mark satisfied-needs on all future operations
223-
for future_op in G.adj[node]:
230+
for future_op in dag.adj[node]:
224231
op_satisfaction[future_op].add(node)
225232
else:
226233
raise AssertionError("Unrecognized network graph node %r" % node)
227234

228-
return unsatisfiables
235+
return unsatisfied
229236

230237

231238
def _solve_dag(self, outputs, inputs):
@@ -245,68 +252,64 @@ def _solve_dag(self, outputs, inputs):
245252
The inputs names of all given inputs.
246253
247254
:return:
248-
the subgraph comprising the solution
249-
255+
the *execution plan*
250256
"""
251-
graph = self.graph
252-
if not outputs:
257+
dag = self.graph
253258

254-
# If caller requested all outputs, the necessary nodes are all
255-
# nodes that are reachable from one of the inputs. Ignore input
256-
# names that aren't in the graph.
257-
necessary_nodes = set() # unordered, not iterated
258-
for input_name in iter(inputs):
259-
if graph.has_node(input_name):
260-
necessary_nodes |= nx.descendants(graph, input_name)
259+
# Ignore input names that aren't in the graph.
260+
graph_inputs = iset(dag.nodes) & inputs # preserve order
261261

262-
else:
262+
# Scream if some requested outputs aren't in the graph.
263+
unknown_outputs = iset(outputs) - dag.nodes
264+
if unknown_outputs:
265+
raise ValueError(
266+
"Unknown output node(s) requested: %s"
267+
% ", ".join(unknown_outputs))
268+
269+
broken_dag = dag.copy() # preserve net's graph
263270

264-
# If the caller requested a subset of outputs, find any nodes that
265-
# are made unecessary because we were provided with an input that's
266-
# deeper into the network graph. Ignore input names that aren't
267-
# in the graph.
268-
unnecessary_nodes = set() # unordered, not iterated
269-
for input_name in iter(inputs):
270-
if graph.has_node(input_name):
271-
unnecessary_nodes |= nx.ancestors(graph, input_name)
272-
273-
# Find the nodes we need to be able to compute the requested
274-
# outputs. Raise an exception if a requested output doesn't
275-
# exist in the graph.
276-
necessary_nodes = set() # unordered, not iterated
277-
for output_name in outputs:
278-
if not graph.has_node(output_name):
279-
raise ValueError("graphkit graph does not have an output "
280-
"node named %s" % output_name)
281-
necessary_nodes |= nx.ancestors(graph, output_name)
282-
283-
# Get rid of the unnecessary nodes from the set of necessary ones.
284-
necessary_nodes -= unnecessary_nodes
285-
286-
# Drop (un-satifiable) operations with partial inputs.
271+
# Break the incoming edges to all given inputs.
272+
#
273+
# Nodes producing any given intermediate inputs are unecessary
274+
# (unless they are also used elsewhere).
275+
# To discover which ones to prune, we break their incoming edges
276+
# and they will drop out while collecting ancestors from the outputs.
277+
for given in graph_inputs:
278+
broken_dag.remove_edges_from(list(broken_dag.in_edges(given)))
279+
280+
if outputs:
281+
# If caller requested specific outputs, we can prune any
282+
# unrelated nodes further up the dag.
283+
ending_in_outputs = set()
284+
for input_name in outputs:
285+
ending_in_outputs.update(nx.ancestors(dag, input_name))
286+
broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))
287+
288+
289+
# Prune (un-satifiable) operations with partial inputs.
287290
# See yahoo/graphkit#18
288291
#
289-
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
290-
necessary_nodes -= set(unsatisfiables)
292+
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
293+
shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
291294

292-
shrinked_dag = graph.subgraph(necessary_nodes)
295+
plan = self._build_execution_plan(shrinked_dag)
293296

294-
return shrinked_dag
297+
return plan
295298

296299

297300
def compile(self, outputs=(), inputs=()):
298301
"""
299-
Solve dag, set the :attr:`execution_plan` and cache it.
302+
Solve dag, set the :attr:`execution_plan`, and cache it.
300303
301-
See :meth:`_solve_dag()` for description
304+
See :meth:`_solve_dag()` for detailed description.
302305
303306
:param iterable outputs:
304307
A list of desired output names. This can also be ``None``, in which
305308
case the necessary steps are all graph nodes that are reachable
306309
from one of the provided inputs.
307310
308311
:param dict inputs:
309-
The inputs names of all given inputs.
312+
The input names of all given inputs.
310313
"""
311314

312315
# return steps if it has already been computed before for this set of inputs and outputs
@@ -317,8 +320,7 @@ def compile(self, outputs=(), inputs=()):
317320
if cache_key in self._cached_execution_plans:
318321
self.execution_plan = self._cached_execution_plans[cache_key]
319322
else:
320-
dag = self._solve_dag(outputs, inputs)
321-
plan = self._build_execution_plan(dag)
323+
plan = self._solve_dag(outputs, inputs)
322324
# save this result in a precomputed cache for future lookup
323325
self.execution_plan = self._cached_execution_plans[cache_key] = plan
324326

@@ -345,9 +347,10 @@ def compute(self, outputs, named_inputs, method=None):
345347
assert isinstance(outputs, (list, tuple)) or outputs is None,\
346348
"The outputs argument must be a list"
347349

348-
# start with fresh data cache
349-
cache = {}
350-
cache.update(named_inputs)
350+
# start with fresh data cache & overwrites
351+
cache = named_inputs.copy()
352+
353+
# Build and set :attr:`execution_plan`.
351354
self.compile(outputs, named_inputs.keys())
352355

353356
# choose a method of execution

test/test_graphkit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
265265
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
266266
)
267267

268-
exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
268+
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
269269
# v1.2.4 is ok
270270
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
271271
# FAILS
272272
# - on v1.2.4 with KeyError: 'a',
273273
# - on #18 (unsatisfied) with no result.
274+
# FIXED on #18+#26 (new dag solver).
274275
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
275276

276277

0 commit comments

Comments
 (0)