Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import abc
import functools
import importlib
import importlib.util
import json
Expand Down Expand Up @@ -792,6 +793,26 @@ def list_available_variables(
results = [Variable.from_node(n) for n in all_nodes]
return results

@functools.cached_property

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for this to be cached? IMO, if the cache is required, would be better to have this as an internal variable explicitly instead of introducing hidden state

def variables(self) -> dict[str, Variable]:
"""Returns all variables in the graph keyed by name."""
return {
node_name: Variable.from_node(node_) for node_name, node_ in self.graph.nodes.items()
}
Comment on lines +797 to +801

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, if we are exposing get_graph, that should already give us the names of the nodes and combining it with get_variable makes the API consistent and cleaner. So maybe we remove variables for now and can always add it in case it is really desired.


def get_variable(self, name: str) -> Variable:
"""Returns a variable by name.

:param name: Name of the variable to return.
:return: Matching HamiltonNode.
:raises KeyError: If the variable does not exist in this Driver's graph.
"""
return self.variables[name]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do build this straight from self.graph so it still retains the hashmap lookup speed


def get_graph(self) -> graph_types.HamiltonGraph:
"""Returns the public HamiltonGraph representation for this Driver."""
return graph_types.HamiltonGraph.from_graph(self.graph)

@capture_function_usage
def display_all_functions(
self,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ def test_driver_variables_exposes_original_function():
assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input


def test_driver_variable_lookup():
dr = Driver({}, tests.resources.very_simple_dag)

assert set(dr.variables) == {"a", "b"}
assert dr.variables["b"].name == "b"
assert dr.get_variable("a").is_external_input is True


def test_driver_get_graph_returns_hamilton_graph():
dr = Driver({}, tests.resources.very_simple_dag)

hamilton_graph = dr.get_graph()

assert hamilton_graph["b"].name == "b"


@pytest.mark.parametrize(
"driver_factory",
[
Expand Down