# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import abc
import importlib
import importlib.util
import json
import logging
import operator
import pathlib
import sys
# required if we want to run this code stand alone.
import typing
import uuid
import warnings
from collections.abc import Callable, Collection, Sequence
from datetime import datetime
from types import ModuleType
from typing import (
Any,
Literal,
Optional,
)
import pandas as pd
from typing_extensions import Self
from hamilton import common, graph_types, htypes
from hamilton.caching.adapter import HamiltonCacheAdapter
from hamilton.caching.stores.base import MetadataStore, ResultStore
from hamilton.dev_utils import deprecation
from hamilton.execution import executors, graph_functions, grouping, state
from hamilton.function_modifiers.validation import DISABLE_DATA_QUALITY_CHECKS_CONFIG_KEY
from hamilton.graph_types import HamiltonNode
from hamilton.io import materialization
from hamilton.io.materialization import ExtractorFactory, MaterializerFactory
from hamilton.lifecycle import base as lifecycle_base
# Help banner that the Driver logs (via logger.error) when graph construction
# or execution raises, pointing users at the community Slack.
SLACK_ERROR_MESSAGE = (
    "-------------------------------------------------------------------\n"
    "Oh no an error! Need help with Hamilton?\n"
    "Join our slack and ask for help! https://join.slack.com/t/hamilton-opensource/shared_invite/zt-2niepkra8-DGKGf_tTYhXuJWBTXtIs4g\n"
    "-------------------------------------------------------------------\n"
)
if __name__ == "__main__":
import base
import graph
import node
else:
from . import base, graph, node
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def capture_function_usage(call_fn: Callable) -> Callable:
    """Decorator retained purely for backwards compatibility; it wraps nothing.

    :param call_fn: the Driver function being decorated.
    :return: ``call_fn`` itself, untouched.
    """
    return call_fn
# Backwards-compatible alias: `Variable` is kept so existing references keep working.
# New code should refer to graph_types.HamiltonNode directly.
Variable = graph_types.HamiltonNode
class InvalidExecutorException(Exception):
    """Signals that the chosen executor cannot run the given graph."""
class GraphExecutor(abc.ABC):
    """Interface for the graph executor.

    An executor runs a function graph, given the graph, the requested output
    variables, inputs, and overrides.
    """

    @abc.abstractmethod
    def execute(
        self,
        fg: graph.FunctionGraph,
        final_vars: list[str | Callable | Variable],
        overrides: dict[str, Any],
        inputs: dict[str, Any],
        run_id: str,
    ) -> dict[str, Any]:
        """Executes a graph in a blocking function.

        :param fg: Graph to execute
        :param final_vars: Variables we want
        :param overrides: Overrides --- these short-circuit computation
        :param inputs: Inputs to the Graph.
        :param run_id: Run ID for the DAG run.
        :return: The output of the final variables, in dictionary form.
        """

    @abc.abstractmethod
    def validate(self, nodes_to_execute: list[node.Node]):
        """Validates whether the executor can execute the given nodes.

        Some executors allow API constructs that others do not support
        (such as Parallelizable[]/Collect[]).

        :param nodes_to_execute: Nodes that would be executed.
        :raises InvalidExecutorException: expected from implementations that cannot run the
            given nodes (this is what ``DefaultGraphExecutor.validate`` does). Nothing is
            returned when validation passes.
        """
[docs]
class DefaultGraphExecutor(GraphExecutor):
    """In-memory executor that runs the requested nodes directly, without tasks."""

    DEFAULT_TASK_NAME = "root"  # Not task-based, so we just assign a default name for a task

    def __init__(self, adapter: lifecycle_base.LifecycleAdapterSet | None = None):
        """Creates the default graph executor.

        :param adapter: Optional adapter to use for execution.
        """
        self.adapter = adapter

    def validate(self, nodes_to_execute: list[node.Node]):
        """Rejects node sets containing parallelizable[]/collect[] nodes.

        :param nodes_to_execute: Nodes that are about to be executed.
        :raises InvalidExecutorException: on the first node whose role is EXPAND or COLLECT.
        """
        for candidate in nodes_to_execute:
            if candidate.node_role not in (node.NodeType.EXPAND, node.NodeType.COLLECT):
                continue
            origin_names = [fn.__qualname__ for fn in candidate.originating_functions]
            raise InvalidExecutorException(
                "Default graph executor cannot handle parallelizable[]/collect[] nodes. "
                f"Node {candidate.name} defined by functions: "
                f"{origin_names}"
            )

    def execute(
        self,
        fg: graph.FunctionGraph,
        final_vars: list[str],
        overrides: dict[str, Any],
        inputs: dict[str, Any],
        run_id: str,
    ) -> dict[str, Any]:
        """Runs the graph in memory and returns the requested variables.

        :param fg: Graph to execute.
        :param final_vars: Names of the variables to return.
        :param overrides: Values passed through to ``fg.execute`` as overrides.
        :param inputs: Runtime inputs; also used as a fallback source for outputs.
        :param run_id: Run ID forwarded to ``fg.execute``.
        :return: Mapping of each requested variable name to its value.
        """
        computed = {}  # filled in by fg.execute
        requested_nodes = [fg.nodes[name] for name in final_vars if name in fg.nodes]
        fg.execute(requested_nodes, computed, overrides, inputs, run_id=run_id)
        # Inputs may themselves be requested as outputs, so fall back to them when a
        # name was not computed. Pre-seeding the computed dict with inputs does not
        # work due to some graphadapter assumptions.
        results = {}
        for name in final_vars:
            results[name] = computed.get(name, inputs.get(name))
        return results
[docs]
class TaskBasedGraphExecutor(GraphExecutor):
    """Executor that groups nodes into tasks and runs them via an execution manager."""

    def validate(self, nodes_to_execute: list[node.Node]):
        """Currently this can run every valid graph, so nothing is checked."""

    def __init__(
        self,
        execution_manager: executors.ExecutionManager,
        grouping_strategy: grouping.GroupingStrategy,
        adapter: lifecycle_base.LifecycleAdapterSet,
    ):
        """Executor for task-based execution. This enables grouping of nodes into tasks, as
        well as parallel execution/dynamic spawning of nodes.

        :param execution_manager: Utility to assign task executors to node groups
        :param grouping_strategy: Utility to group nodes into tasks
        :param adapter: Lifecycle adapter set; passed to task planning and used to fire the
            ``post_task_group`` hook.
        """
        self.execution_manager = execution_manager
        self.grouping_strategy = grouping_strategy
        self.adapter = adapter

    def execute(
        self,
        fg: graph.FunctionGraph,
        final_vars: list[str],
        overrides: dict[str, Any],
        inputs: dict[str, Any],
        run_id: str,
    ) -> dict[str, Any]:
        """Executes a graph, task by task. This blocks until completion.

        This does the following:

        1. Groups the nodes into tasks
        2. Creates an execution state and a results cache
        3. Runs it to completion, populating the results cache
        4. Returning the results from the results cache

        :param fg: Graph to execute.
        :param final_vars: Names of the variables to return.
        :param overrides: Override values; used to pre-populate the results cache.
        :param inputs: Runtime inputs; merged with the graph config before execution.
        :param run_id: Run ID for the DAG run.
        :return: The values read from the results cache for ``final_vars``.
        """
        inputs = graph_functions.combine_config_and_inputs(fg.config, inputs)
        (
            transform_nodes_required_for_execution,
            user_defined_nodes_required_for_execution,
        ) = fg.get_upstream_nodes(final_vars, runtime_inputs=inputs, runtime_overrides=overrides)
        all_nodes_required_for_execution = list(
            set(transform_nodes_required_for_execution).union(
                user_defined_nodes_required_for_execution
            )
        )
        grouped_nodes = self.grouping_strategy.group_nodes(
            all_nodes_required_for_execution
        )  # pure function transform
        # Instantiate a result cache so we can use later.
        # Inputs and overrides pre-populate it; on a key clash the input value wins
        # because it is unpacked last.
        prehydrated_results = {**overrides, **inputs}
        results_cache = state.DictBasedResultCache(prehydrated_results)
        # Create tasks from the grouped nodes, filtering/pruning as we go
        tasks = grouping.create_task_plan(grouped_nodes, final_vars, overrides, self.adapter)
        task_ids = [task.base_id for task in tasks]
        # Guard against a missing adapter so the hook lookup cannot fail on None.
        if self.adapter is not None and self.adapter.does_hook("post_task_group", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "post_task_group",
                run_id=run_id,
                task_ids=task_ids,
            )
        # Create a task graph and execution state
        execution_state = state.ExecutionState(
            tasks, results_cache, run_id
        )  # Stateful storage for the DAG
        # Blocking call to run through until completion
        executors.run_graph_to_completion(execution_state, self.execution_manager)
        # Read the final variables from the result cache
        return results_cache.read(final_vars)
[docs]
class Driver:
"""This class orchestrates creating and executing the DAG to create a dataframe.
.. code-block:: python
from hamilton import driver
from hamilton import base
# 1. Setup config or invariant input.
config = {}
# 2. we need to tell hamilton where to load function definitions from
import my_functions
# or programmatically (e.g. you can script module loading)
module_name = "my_functions"
my_functions = importlib.import_module(module_name)
# 3. Determine the return type -- default is a pandas.DataFrame.
adapter = base.SimplePythonDataFrameGraphAdapter() # See GraphAdapter docs for more details.
# These all feed into creating the driver & thus DAG.
dr = driver.Driver(config, my_functions, adapter=adapter)
"""
def __getstate__(self):
"""Used for serialization."""
# Copy the object's state from self.__dict__
state = self.__dict__.copy()
# Remove the unpicklable entries -- right now it's the modules tracked.
state["__graph_module_names"] = [
importlib.util.find_spec(m.__name__).name for m in state["graph_modules"]
]
del state["graph_modules"] # remove from state
return state
def __setstate__(self, state):
"""Used for deserialization."""
# Restore instance attributes
self.__dict__.update(state)
# Reinitialize the unpicklable entries
# assumption is that the modules are importable in the new process
self.graph_modules = []
for n in state["__graph_module_names"]:
try:
g_module = importlib.import_module(n)
except ImportError:
logger.error(f"Could not import module {n}")
continue
else:
self.graph_modules.append(g_module)
@staticmethod
def _perform_graph_validations(
    adapter: lifecycle_base.LifecycleAdapterSet,
    graph: graph.FunctionGraph,
    graph_modules: typing.Sequence[ModuleType],
):
    """Runs node-level, then graph-level validators; static so it holds no local state.

    :param adapter: Adapter to use for validation.
    :param graph: Graph to validate.
    :param graph_modules: Modules to validate.
    :raises lifecycle_base.ValidationException: if any node validator or graph validator fails.
    """
    if adapter.does_validation("validate_node"):
        failures_by_node = {}
        for candidate in graph.nodes.values():
            outcomes = adapter.call_all_validators_sync(
                "validate_node",
                output_only_failures=True,  # just failures so we can just store messages
                created_node=candidate,
            )
            failure_messages = [outcome.error for outcome in outcomes if not outcome.success]
            if failure_messages:
                failures_by_node[candidate.name] = failure_messages
        if failures_by_node:
            message_separator = "\n\t"
            report_lines = []
            # Sort by node name for a stable report.
            for node_name, messages in sorted(
                failures_by_node.items(), key=operator.itemgetter(0)
            ):
                report_lines.append(f"{node_name}: {message_separator.join(messages)}")
            header = f"Node validation failed! {len(report_lines)} errors encountered:\n "
            raise lifecycle_base.ValidationException(header + "\n ".join(report_lines))
    if adapter.does_validation("validate_graph"):
        graph_failures = adapter.call_all_validators_sync(
            "validate_graph",
            output_only_failures=True,
            graph=graph,
            modules=graph_modules,
            config=graph.config,
        )
        if graph_failures:
            separator = "\t"
            sorted_errors = sorted(failure.error for failure in graph_failures)
            header = f"Graph validation failed! {len(sorted_errors)} errors encountered:{separator}"
            raise lifecycle_base.ValidationException(header + separator.join(sorted_errors))
[docs]
def __init__(
    self,
    config: dict[str, Any],
    *modules: ModuleType,
    adapter: lifecycle_base.LifecycleAdapter
    | list[lifecycle_base.LifecycleAdapter]
    | None = None,
    allow_module_overrides: bool = False,
    _materializers: typing.Sequence[ExtractorFactory | MaterializerFactory] = None,
    _graph_executor: GraphExecutor = None,
    _use_legacy_adapter: bool = True,
):
    """Constructor: creates a DAG given the configuration & modules to crawl.

    Fires the ``pre_do_anything`` hook (if registered), builds the function graph, applies any
    materializers, validates the graph, fires ``post_graph_construct`` (if registered), and
    finally stores the adapter, graph executor and config on the instance.

    :param config: This is a dictionary of initial data & configuration.
        The contents are used to help create the DAG.
    :param modules: Python module objects you want to inspect for Hamilton Functions.
    :param adapter: Optional. A way to wire in another way of "executing" a hamilton graph.
        Defaults to using original Hamilton adapter which is single threaded in memory python.
    :param allow_module_overrides: Optional. Same named functions get overridden by later modules.
        The order of listing the modules is important, since later ones will overwrite the previous ones.
        This is a global call affecting all imported modules.
        See https://github.com/apache/hamilton/tree/main/examples/module_overrides for more info.
    :param _materializers: Not public facing, do not use this parameter. This is injected by the builder.
    :param _graph_executor: Not public facing, do not use this parameter. This is injected by the builder.
        If you need to tune execution, use the builder to do so.
    :param _use_legacy_adapter: Not public facing, do not use this parameter.
        This represents whether or not to use the legacy adapter. Defaults to True, as this should be
        backwards compatible. In Hamilton 2.0.0, this will be removed.
    :raises Exception: anything raised while building or validating the graph is re-raised
        after the help banner is logged.
    """
    # Identifier for this driver instance; execute() generates a separate ID per run.
    self.driver_run_id = uuid.uuid4()
    # Normalize whatever the caller passed into the adapter object used for the hook,
    # validation and method dispatch below.
    adapter = self.normalize_adapter_input(adapter, use_legacy_adapter=_use_legacy_adapter)
    if adapter.does_hook("pre_do_anything", is_async=False):
        adapter.call_all_lifecycle_hooks_sync("pre_do_anything")
    # Stored as the tuple captured by *modules.
    self.graph_modules = modules
    try:
        self.graph = graph.FunctionGraph.from_modules(
            *modules,
            config=config,
            adapter=adapter,
            allow_module_overrides=allow_module_overrides,
        )
        if _materializers:
            # Split the factories by kind, then rewrite the graph to include them.
            materializer_factories, extractor_factories = self._process_materializers(
                _materializers
            )
            self.graph = materialization.modify_graph(
                self.graph, materializer_factories, extractor_factories
            )
        # Validation runs on the final graph, i.e. after materializers were applied.
        Driver._perform_graph_validations(adapter, graph=self.graph, graph_modules=modules)
        if adapter.does_hook("post_graph_construct", is_async=False):
            adapter.call_all_lifecycle_hooks_sync(
                "post_graph_construct", graph=self.graph, modules=modules, config=config
            )
        self.adapter = adapter
        # Fall back to the in-memory default executor when the builder injected none.
        if _graph_executor is None:
            _graph_executor = DefaultGraphExecutor(self.adapter)
        self.graph_executor = _graph_executor
        self.config = config
    except Exception as e:
        # Log the help banner, then propagate the original error unchanged.
        logger.error(SLACK_ERROR_MESSAGE)
        raise e
def _repr_mimebundle_(self, include=None, exclude=None, **kwargs):
"""Attribute read by notebook renderers
This returns the attribute of the `graphviz.Digraph` returned by `self.display_all_functions()`
The parameters `include`, `exclude`, and `**kwargs` are required, but not explicitly used
ref: https://ipython.readthedocs.io/en/stable/config/integrating.html
"""
dot = self.display_all_functions()
return dot._repr_mimebundle_(include=include, exclude=exclude, **kwargs)
[docs]
def execute(
    self,
    final_vars: list[str | Callable | Variable],
    overrides: dict[str, Any] = None,
    display_graph: bool = False,
    inputs: dict[str, Any] = None,
) -> Any:
    """Executes computation.

    Fires the ``pre_graph_execute`` hook (if registered), runs the graph, passes the raw outputs
    through the adapter's ``do_build_result`` method when one is registered, and always fires
    ``post_graph_execute`` (if registered) with the success flag, the error (if any) and the results.

    :param final_vars: the final list of outputs we want to compute.
    :param overrides: values that will override "nodes" in the DAG.
    :param display_graph: DEPRECATED. Whether we want to display the graph being computed.
    :param inputs: Runtime inputs to the DAG.
    :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter.
        See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas
        dataframe.
    :raises Exception: any error raised during execution or result building is re-raised after
        the help banner is logged.
    """
    if display_graph:
        warnings.warn(
            "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
            "Please use visualize_execution().",
            DeprecationWarning,
            stacklevel=2,
        )
    # Fresh ID per execution; handed to the lifecycle hooks and to the graph executor.
    run_id = str(uuid.uuid4())
    run_successful = True
    error_execution = None
    outputs = None
    # Normalize requested outputs (strings, functions, Variables) into node names.
    _final_vars = self._create_final_vars(final_vars)
    if self.adapter.does_hook("pre_graph_execute", is_async=False):
        self.adapter.call_all_lifecycle_hooks_sync(
            "pre_graph_execute",
            run_id=run_id,
            graph=self.graph,
            final_vars=_final_vars,
            inputs=inputs,
            overrides=overrides,
        )
    try:
        outputs = self.__raw_execute(
            _final_vars, overrides, display_graph, inputs=inputs, _run_id=run_id
        )
        if self.adapter.does_method("do_build_result", is_async=False):
            # Build the result if we have a result builder
            outputs = self.adapter.call_lifecycle_method_sync(
                "do_build_result", outputs=outputs
            )
        # Otherwise just return a dict
    except Exception as e:
        # Record the failure for the post hook, log the help banner, then propagate.
        run_successful = False
        logger.error(SLACK_ERROR_MESSAGE)
        error_execution = e
        raise e
    finally:
        # Runs on success and on failure; on failure `outputs` holds whatever was
        # assigned before the error (None if execution itself failed).
        if self.adapter.does_hook("post_graph_execute", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "post_graph_execute",
                run_id=run_id,
                graph=self.graph,
                success=run_successful,
                error=error_execution,
                results=outputs,
            )
    return outputs
def _create_final_vars(self, final_vars: list[str | Callable | Variable]) -> list[str]:
    """Normalizes the requested outputs into a list of names.

    :param final_vars: requested outputs; entries may be strings, functions or Variables.
    :return: list of strings in the order that final_vars was provided.
    """
    module_names = set()
    for tracked_module in self.graph_modules:
        module_names.add(tracked_module.__name__)
    return common.convert_output_values(final_vars, module_names)
[docs]
@deprecation.deprecated(
    warn_starting=(1, 0, 0),
    fail_starting=(2, 0, 0),
    use_this=None,
    explanation="This has become a private method and does not guarantee that all the adapters work correctly.",
    migration_guide="Don't use this entry point for execution directly. Always go through `.execute()`or `.materialize()`.",
)
def raw_execute(
    self,
    final_vars: list[str],
    overrides: dict[str, Any] = None,
    display_graph: bool = False,
    inputs: dict[str, Any] = None,
    _fn_graph: graph.FunctionGraph = None,
) -> dict[str, Any]:
    """Raw execute function that does the meat of execute.

    Deprecated. Don't use this entry point for execution directly. Always go through
    `.execute()` or `.materialize()`.

    In case you are using `.raw_execute()` directly, please switch to `.execute()` using a
    `base.DictResult()`. Note: `base.DictResult()` is the default return of execute if you are
    using the `driver.Builder()` class to create a `Driver()` object.

    This method generates its own run ID and fires the ``pre_graph_execute`` /
    ``post_graph_execute`` hooks itself; it does not call ``do_build_result``.

    :param final_vars: Final variables to compute
    :param overrides: Overrides to run.
    :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
    :param inputs: Runtime inputs to the DAG
    :param _fn_graph: Internal. Function graph to run instead of ``self.graph``.
    :return: the dictionary returned by the graph executor.
    """
    function_graph = _fn_graph if _fn_graph is not None else self.graph
    run_id = str(uuid.uuid4())
    nodes, user_nodes = function_graph.get_upstream_nodes(final_vars, inputs, overrides)
    Driver.validate_inputs(
        function_graph, self.adapter, user_nodes, inputs, nodes
    )  # TODO -- validate within the function graph itself
    if display_graph:  # deprecated flow.
        warnings.warn(
            "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
            "Please use visualize_execution().",
            DeprecationWarning,
            stacklevel=2,
        )
        self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True})
        # NOTE(review): the cycle check is assumed to belong to the deprecated display
        # flow (indentation was reconstructed) -- confirm against the upstream file.
        if self.has_cycles(
            final_vars, function_graph
        ):  # here for backwards compatible driver behavior.
            raise ValueError("Error: cycles detected in your graph.")
    all_nodes = nodes | user_nodes
    # Let the executor reject graphs it cannot run before any hook fires.
    self.graph_executor.validate(list(all_nodes))
    if self.adapter.does_hook("pre_graph_execute", is_async=False):
        self.adapter.call_all_lifecycle_hooks_sync(
            "pre_graph_execute",
            run_id=run_id,
            graph=function_graph,
            final_vars=final_vars,
            inputs=inputs,
            overrides=overrides,
        )
    results = None
    error = None
    success = False
    try:
        results = self.graph_executor.execute(
            function_graph,
            final_vars,
            overrides if overrides is not None else {},
            inputs if inputs is not None else {},
            run_id,
        )
        success = True
    except Exception as e:
        # Capture the error for the post hook, then propagate it.
        error = e
        success = False
        raise e
    finally:
        # Always report the outcome, whether execution succeeded or failed.
        if self.adapter.does_hook("post_graph_execute", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "post_graph_execute",
                run_id=run_id,
                graph=function_graph,
                success=success,
                error=error,
                results=results,
            )
    return results
def __raw_execute(
    self,
    final_vars: list[str],
    overrides: dict[str, Any] = None,
    display_graph: bool = False,
    inputs: dict[str, Any] = None,
    _fn_graph: graph.FunctionGraph = None,
    _run_id: str = None,
) -> dict[str, Any]:
    """Raw execute function that does the meat of execute.

    Private method since the result building and post_graph_execute lifecycle hooks are
    performed outside and so this returns an incomplete result.

    :param final_vars: Final variables to compute
    :param overrides: Overrides to run.
    :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
    :param inputs: Runtime inputs to the DAG
    :param _fn_graph: Function graph to run instead of ``self.graph``.
    :param _run_id: Run ID forwarded to the graph executor.
    :return: the dictionary returned by the graph executor.
    :raises ValueError: if cycles are detected in the deprecated display flow.
    """
    function_graph = _fn_graph if _fn_graph is not None else self.graph
    nodes, user_nodes = function_graph.get_upstream_nodes(final_vars, inputs, overrides)
    Driver.validate_inputs(
        function_graph, self.adapter, user_nodes, inputs, nodes
    )  # TODO -- validate within the function graph itself
    if display_graph:  # deprecated flow.
        warnings.warn(
            "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
            "Please use visualize_execution().",
            DeprecationWarning,
            stacklevel=2,
        )
        self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True})
        # NOTE(review): the cycle check is assumed to belong to the deprecated display
        # flow (indentation was reconstructed) -- confirm against the upstream file.
        if self.has_cycles(
            final_vars, function_graph
        ):  # here for backwards compatible driver behavior.
            raise ValueError("Error: cycles detected in your graph.")
    all_nodes = nodes | user_nodes
    # Let the executor reject graphs it cannot run before executing anything.
    self.graph_executor.validate(list(all_nodes))
    # Errors propagate to the caller unchanged; execute() handles logging and hooks.
    return self.graph_executor.execute(
        function_graph,
        final_vars,
        overrides if overrides is not None else {},
        inputs if inputs is not None else {},
        _run_id,
    )
[docs]
@capture_function_usage
def list_available_variables(
    self, *, tag_filter: dict[str, str | None | list[str]] | None = None
) -> list[Variable]:
    """Returns available variables, i.e. outputs.

    These variables correspond 1:1 with nodes in the DAG, and contain the following information:

    1. name: the name of the node
    2. tags: the tags associated with this node
    3. type: The type of data this node returns
    4. is_external_input: Whether this node represents an external input (required from outside),
       or not (has a function specifying its behavior).

    ``tag_filter`` is keyword-only:

    .. code-block:: python

        # gets all
        dr.list_available_variables()
        # gets exact matching tag name and tag value
        dr.list_available_variables(tag_filter={"TAG_NAME": "TAG_VALUE"})
        # gets all matching tag name and at least one of the values in the list
        dr.list_available_variables(tag_filter={"TAG_NAME": ["TAG_VALUE1", "TAG_VALUE2"]})
        # gets all with matching tag name, irrespective of value
        dr.list_available_variables(tag_filter={"TAG_NAME": None})
        # AND query between the two tags (i.e. both need to match)
        dr.list_available_variables(tag_filter={"TAG_NAME": "TAG_VALUE", "TAG_NAME2": "TAG_VALUE2"})

    :param tag_filter: A dictionary of tags to filter by. Only nodes matching the tags and their values will
        be returned. If the value for a tag is None, then we will return all nodes with that tag. If the value
        is non-empty we will return all nodes with that tag and that value.
    :return: list of available variables (i.e. outputs).
    :raises ValueError: if any filter value is not None, a string, or a non-empty list.
    """
    all_nodes = self.graph.get_nodes()
    if not tag_filter:
        # No filter (None or empty dict): every node is returned.
        return [Variable.from_node(n) for n in all_nodes]
    valid_filter_values = all(
        value is None or isinstance(value, str) or (isinstance(value, list) and len(value) != 0)
        for value in tag_filter.values()
    )
    if not valid_filter_values:
        raise ValueError("All tag query values must be a string or list of strings")
    return [Variable.from_node(n) for n in all_nodes if node.matches_query(n.tags, tag_filter)]
[docs]
@capture_function_usage
def display_all_functions(
    self,
    output_file_path: str = None,
    render_kwargs: dict = None,
    graphviz_kwargs: dict = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Displays the graph of all functions loaded!

    Thin wrapper around ``self.graph.display_all``; ``show_schema`` is forwarded to it
    as ``display_fields``.

    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph-all.dot'. Optional. No need to pass it in if you're in a Jupyter Notebook.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
        If you do not want to view the file, pass in `{'view':False}`.
        See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
    :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
        See https://graphviz.org/doc/info/attrs.html for options.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, display the schema of the DAG if
        the nodes have schema data provided
    :param custom_style_function: Optional. Custom style function. See example in repository for example use.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: the graphviz object if you want to do more with it.
        If returned as the result in a Jupyter Notebook cell, it will render.
        None is returned when rendering raises ImportError (a warning is logged instead).
    """
    try:
        return self.graph.display_all(
            output_file_path,
            render_kwargs,
            graphviz_kwargs,
            show_legend=show_legend,
            orient=orient,
            hide_inputs=hide_inputs,
            deduplicate_inputs=deduplicate_inputs,
            display_fields=show_schema,
            custom_style_function=custom_style_function,
            keep_dot=keep_dot,
        )
    except ImportError as e:
        # Best effort: a missing import is logged, and the method falls through to return None.
        logger.warning(f"Unable to import {e}", exc_info=True)
@staticmethod
def _visualize_execution_helper(
    fn_graph: graph.FunctionGraph,
    adapter: lifecycle_base.LifecycleAdapterSet,
    final_vars: list[str],
    output_file_path: str,
    render_kwargs: dict,
    inputs: dict[str, Any] = None,
    graphviz_kwargs: dict = None,
    overrides: dict[str, Any] = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    bypass_validation: bool = False,
    keep_dot: bool = False,
):
    """Renders the execution path for ``final_vars`` on a passed-in function graph.

    :param fn_graph: The function graph to visualize.
    :param adapter: Adapter handed to input validation.
    :param final_vars: The final variables to compute.
    :param output_file_path: The path to save the graph to.
    :param render_kwargs: The kwargs to pass to the graphviz render function.
    :param inputs: The inputs to the DAG.
    :param graphviz_kwargs: The kwargs to pass to the graphviz graph object.
    :param overrides: Overrides to the DAG; matching nodes are tagged as overrides.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
    :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
    :param custom_style_function: Optional. Custom style function.
    :param bypass_validation: If True, skip input validation before rendering.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: the graphviz object if you want to do more with it; None if rendering raised ImportError.
    """
    # TODO should determine if the visualization logic should live here or in the graph.py module
    upstream_nodes, input_nodes = fn_graph.get_upstream_nodes(final_vars, inputs, overrides)
    if not bypass_validation:
        try:
            Driver.validate_inputs(fn_graph, adapter, input_nodes, inputs, upstream_nodes)
        except ValueError as e:
            # Python 3.11 enables the more succinct `.add_note()` syntax
            hint = "Use `bypass_validation=True` to skip validation"
            e.args = ((f"{e.args[0]}; {hint}",) + e.args[1:]) if e.args else (hint,)
            raise e

    modifiers = {fv: {graph.VisualizationNodeModifiers.IS_OUTPUT} for fv in final_vars}
    for input_node in input_nodes:
        modifiers.setdefault(input_node.name, set()).add(
            graph.VisualizationNodeModifiers.IS_USER_INPUT
        )

    displayed_nodes = upstream_nodes | input_nodes
    if overrides is not None:
        # Upstream traversal includes overridden nodes, so tag them explicitly.
        for shown in displayed_nodes:
            if shown.name in overrides:
                modifiers.setdefault(shown.name, set()).add(
                    graph.VisualizationNodeModifiers.IS_OVERRIDE
                )

    try:
        return fn_graph.display(
            displayed_nodes,
            output_file_path=output_file_path,
            render_kwargs=render_kwargs,
            graphviz_kwargs=graphviz_kwargs,
            node_modifiers=modifiers,
            strictly_display_only_passed_in_nodes=True,
            show_legend=show_legend,
            orient=orient,
            hide_inputs=hide_inputs,
            deduplicate_inputs=deduplicate_inputs,
            display_fields=show_schema,
            custom_style_function=custom_style_function,
            config=fn_graph._config,
            keep_dot=keep_dot,
        )
    except ImportError as e:
        logger.warning(f"Unable to import {e}", exc_info=True)
[docs]
@capture_function_usage
def visualize_execution(
    self,
    final_vars: list[str | Callable | Variable],
    output_file_path: str = None,
    render_kwargs: dict = None,
    inputs: dict[str, Any] = None,
    graphviz_kwargs: dict = None,
    overrides: dict[str, Any] = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    bypass_validation: bool = False,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Visualizes Execution.

    Shapes:

    - ovals are nodes/functions
    - rectangles are nodes/functions that are requested as output
    - shapes with dotted lines are inputs required to run the DAG.

    :param final_vars: the outputs we want to compute. They will become rectangles in the graph.
    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph.dot'. Optional. No need to pass it in if you're in a Jupyter Notebook.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function. Defaults to viewing.
        If you do not want to view the file, pass in `{'view':False}`.
        See https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Graph.render for other options.
    :param inputs: Optional. Runtime inputs to the DAG.
    :param graphviz_kwargs: Optional. Kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the produced image.
        See https://graphviz.org/doc/info/attrs.html for options.
    :param overrides: Optional. Overrides to the DAG. They are forwarded to the visualization helper,
        which tags nodes whose names appear here as overrides.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
    :param custom_style_function: Optional. Custom style function.
    :param bypass_validation: If True, skip input validation before rendering.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: the graphviz object if you want to do more with it.
        If returned as the result in a Jupyter Notebook cell, it will render.
    """
    # Normalize requested outputs (strings, functions, Variables) into node names.
    _final_vars = self._create_final_vars(final_vars)
    return self._visualize_execution_helper(
        self.graph,
        self.adapter,
        _final_vars,
        output_file_path,
        render_kwargs,
        inputs,
        graphviz_kwargs,
        overrides,
        show_legend=show_legend,
        orient=orient,
        hide_inputs=hide_inputs,
        deduplicate_inputs=deduplicate_inputs,
        show_schema=show_schema,
        custom_style_function=custom_style_function,
        bypass_validation=bypass_validation,
        keep_dot=keep_dot,
    )
[docs]
@capture_function_usage
def export_execution(
    self,
    final_vars: list[str],
    inputs: dict[str, Any] = None,
    overrides: dict[str, Any] = None,
) -> str:
    """Serializes the part of the graph needed for ``final_vars`` to JSON.

    :param final_vars: The final variables to compute.
    :param inputs: Optional. The inputs to the DAG.
    :param overrides: Optional. Overrides to the DAG.
    :return: JSON string of the form ``{"nodes": [...]}`` where the node
        dictionaries are sorted by their ``"name"`` entry.
    """
    upstream, user_supplied = self.graph.get_upstream_nodes(final_vars, inputs, overrides)
    # Validate before exporting so that a broken request fails here.
    Driver.validate_inputs(self.graph, self.adapter, user_supplied, inputs, upstream)
    # Both the nodes to compute and the user-provided ones are part of the export.
    node_dicts = [HamiltonNode.from_node(n).as_dict() for n in upstream | user_supplied]
    node_dicts.sort(key=operator.itemgetter("name"))
    return json.dumps({"nodes": node_dicts})
[docs]
@capture_function_usage
def has_cycles(
    self,
    final_vars: list[str | Callable | Variable],
    _fn_graph: graph.FunctionGraph = None,
) -> bool:
    """Checks that the created graph does not have cycles.

    :param final_vars: the outputs we want to compute.
    :param _fn_graph: the function graph to check for cycles, used internally.
        Defaults to the driver's own graph when not provided.
    :return: boolean True for cycles, False for no cycles.
    """
    function_graph = _fn_graph if _fn_graph is not None else self.graph
    _final_vars = self._create_final_vars(final_vars)
    # get graph we'd be executing over
    nodes, user_nodes = function_graph.get_upstream_nodes(_final_vars)
    # Run the check on the same graph the nodes were taken from, so that a
    # caller-supplied `_fn_graph` is honoured for the cycle detection as well.
    return function_graph.has_cycles(nodes, user_nodes)
[docs]
@capture_function_usage
def what_is_downstream_of(self, *node_names: str) -> list[Variable]:
    """Lists everything that is downstream of the given function(s), i.e. node(s).

    :param node_names: names of function(s) that are starting points for traversing the graph.
    :return: list of "variables" (i.e. nodes), inclusive of the function names, that are
        downstream of the passed in function names.
    """
    starting_points = list(node_names)
    reachable = self.graph.get_downstream_nodes(starting_points)
    return list(map(Variable.from_node, reachable))
[docs]
@capture_function_usage
def display_downstream_of(
    self,
    *node_names: str,
    output_file_path: str = None,
    render_kwargs: dict = None,
    graphviz_kwargs: dict = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Draws the DAG going forwards from the passed in function name(s).

    Note: every node that is drawn also gets its direct parents drawn, so the picture can
    contain more nodes than strictly what is downstream of the passed in function name(s).

    :param node_names: names of function(s) that are starting points for traversing the graph.
    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph.dot'. Optional. No need to pass it in if you're in a Jupyter Notebook.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function.
        Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`.
    :param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the
        produced image.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
    :param custom_style_function: Optional. Custom style function.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: the graphviz object if you want to do more with it.
        If returned as the result in a Jupyter Notebook cell, it will render.
        ``None`` when displaying raises an ``ImportError`` (a warning is logged instead).
    """
    reachable = self.graph.get_downstream_nodes(list(node_names))
    # Start from the downstream nodes and pull in the direct dependencies of each one.
    to_render = set(reachable)
    for current in reachable:
        to_render.update(current.dependencies)
    try:
        return self.graph.display(
            to_render,
            output_file_path,
            render_kwargs=render_kwargs,
            graphviz_kwargs=graphviz_kwargs,
            strictly_display_only_passed_in_nodes=True,
            show_legend=show_legend,
            orient=orient,
            hide_inputs=hide_inputs,
            deduplicate_inputs=deduplicate_inputs,
            display_fields=show_schema,
            custom_style_function=custom_style_function,
            config=self.graph._config,
            keep_dot=keep_dot,
        )
    except ImportError as e:
        # Best effort: visualization is optional, so only warn.
        logger.warning(f"Unable to import {e}", exc_info=True)
    return None
[docs]
@capture_function_usage
def display_upstream_of(
    self,
    *node_names: str,
    output_file_path: str = None,
    render_kwargs: dict = None,
    graphviz_kwargs: dict = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Draws the DAG going backwards from the passed in function name(s).

    Note: for any "node" visualized, we will also add its parents to the visualization as well,
    so there could be more nodes visualized than strictly what is upstream of the passed in
    function name(s).

    :param node_names: names of function(s) that are starting points for traversing the graph.
    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph.dot'. Optional. No need to pass it in if you're in a Jupyter Notebook.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function.
        Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`.
        Optional.
    :param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the
        produced image. Optional.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
    :param custom_style_function: Optional. Custom style function.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: the graphviz object if you want to do more with it.
        If returned as the result in a Jupyter Notebook cell, it will render.
        ``None`` when displaying raises an ``ImportError`` (a warning is logged instead).
    """
    required_nodes, input_nodes = self.graph.get_upstream_nodes(list(node_names))
    # Tag the user-provided nodes so that they are styled as inputs.
    node_modifiers = {
        input_node.name: {graph.VisualizationNodeModifiers.IS_USER_INPUT}
        for input_node in input_nodes
    }
    try:
        return self.graph.display(
            required_nodes,
            output_file_path,
            render_kwargs=render_kwargs,
            graphviz_kwargs=graphviz_kwargs,
            strictly_display_only_passed_in_nodes=True,
            node_modifiers=node_modifiers,
            show_legend=show_legend,
            orient=orient,
            hide_inputs=hide_inputs,
            deduplicate_inputs=deduplicate_inputs,
            display_fields=show_schema,
            custom_style_function=custom_style_function,
            config=self.graph._config,
            keep_dot=keep_dot,
        )
    except ImportError as e:
        # Best effort: visualization is optional, so only warn.
        logger.warning(f"Unable to import {e}", exc_info=True)
    return None
[docs]
@capture_function_usage
def what_is_upstream_of(self, *node_names: str) -> list[Variable]:
    """Lists everything that is upstream of the given function(s), i.e. node(s).

    :param node_names: names of function(s) that are starting points for traversing the
        graph backwards.
    :return: list of "variables" (i.e. nodes), inclusive of the function names, that are
        upstream of the passed in function names.
    """
    # The second element (user-provided nodes) is not needed for this listing.
    required_nodes = self.graph.get_upstream_nodes(list(node_names))[0]
    return list(map(Variable.from_node, required_nodes))
[docs]
@capture_function_usage
def what_is_the_path_between(
    self, upstream_node_name: str, downstream_node_name: str
) -> list[Variable]:
    """Returns the nodes that lie on the path between two nodes.

    Note: this is inclusive of the two nodes, and returns an unsorted list of nodes.

    :param upstream_node_name: the name of the node that we want to start from.
    :param downstream_node_name: the name of the node that we want to end at.
    :return: Nodes representing the path between the two nodes, inclusive of the two nodes,
        unsorted. Returns empty list if no path exists.
    :raise ValueError: if the upstream or downstream node name is not in the graph.
    """
    known_names = {n.name for n in self.graph.get_nodes()}
    # Both endpoints have to exist; the upstream one is checked first.
    endpoints = (("Upstream", upstream_node_name), ("Downstream", downstream_node_name))
    for role, name in endpoints:
        if name not in known_names:
            raise ValueError(f"{role} node {name} not found in graph.")
    path_nodes = self._get_nodes_between(upstream_node_name, downstream_node_name)
    return list(map(Variable.from_node, path_nodes))
def _get_nodes_between(
    self, upstream_node_name: str, downstream_node_name: str
) -> set[node.Node]:
    """Computes the nodes on the path between two nodes, inclusive of the two nodes.

    Assumes that the nodes exist in the graph.

    :param upstream_node_name: the name of the node that we want to start from.
    :param downstream_node_name: the name of the node that we want to end at.
    :return: set of nodes that comprise the path between the two nodes, inclusive of the
        two nodes.
    """
    reachable_from_start = self.graph.get_downstream_nodes([upstream_node_name])
    # we skip user_nodes because it'll be the upstream node, or it wont matter.
    required_for_end, _ = self.graph.get_upstream_nodes([downstream_node_name])
    # A node is on the path when it is both downstream of the start and upstream of the end.
    return set(reachable_from_start) & set(required_for_end)
[docs]
@capture_function_usage
def visualize_path_between(
    self,
    upstream_node_name: str,
    downstream_node_name: str,
    output_file_path: str | None = None,
    render_kwargs: dict = None,
    graphviz_kwargs: dict = None,
    strict_path_visualization: bool = False,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Draws the path between two nodes.

    This is useful for debugging and understanding the path between two nodes.

    :param upstream_node_name: the name of the node that we want to start from.
    :param downstream_node_name: the name of the node that we want to end at.
    :param output_file_path: the full URI of path + file name to save the dot file to.
        E.g. 'some/path/graph.dot'. Pass in None to skip saving any file.
    :param render_kwargs: a dictionary of values we'll pass to graphviz render function.
        Defaults to viewing. If you do not want to view the file, pass in `{'view':False}`.
    :param graphviz_kwargs: Kwargs to be passed to the graphviz graph object to configure it.
        E.g. dict(graph_attr={'ratio': '1'}) will set the aspect ratio to be equal of the
        produced image.
    :param strict_path_visualization: If True, only the nodes in the path will be visualized.
        If False, the nodes in the path and their dependencies, i.e. parents, will be
        visualized.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
    :param custom_style_function: Optional. Custom style function.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: graphviz object. ``None`` when displaying raises an ``ImportError``
        (a warning is logged instead).
    :raise ValueError: if the upstream or downstream node names are not found in the graph,
        or there is no path between them.
    """
    render_kwargs = {} if render_kwargs is None else render_kwargs
    graphviz_kwargs = {} if graphviz_kwargs is None else graphviz_kwargs

    # ensure that the nodes exist
    known_names = {n.name for n in self.graph.get_nodes()}
    if upstream_node_name not in known_names:
        raise ValueError(f"Upstream node {upstream_node_name} not found in graph.")
    if downstream_node_name not in known_names:
        raise ValueError(f"Downstream node {downstream_node_name} not found in graph.")

    # set whether the node is user input
    node_modifiers = {
        n.name: {graph.VisualizationNodeModifiers.IS_USER_INPUT}
        for n in self.graph.get_nodes()
        if n.user_defined
    }

    # create nodes that constitute the path
    path_nodes = self._get_nodes_between(upstream_node_name, downstream_node_name)
    if not path_nodes:
        raise ValueError(
            f"No path found between {upstream_node_name} and {downstream_node_name}."
        )

    to_render = set()
    for path_node in path_nodes:
        # mark the node as part of the path, keeping any modifier it already has
        node_modifiers.setdefault(path_node.name, set()).add(
            graph.VisualizationNodeModifiers.IS_PATH
        )
        to_render.add(path_node)
        if strict_path_visualization is False:
            # non-strict mode also shows the direct parents of every path node
            to_render.update(path_node.dependencies)

    try:
        return self.graph.display(
            to_render,
            output_file_path,
            render_kwargs=render_kwargs,
            graphviz_kwargs=graphviz_kwargs,
            node_modifiers=node_modifiers,
            strictly_display_only_passed_in_nodes=True,
            show_legend=show_legend,
            orient=orient,
            hide_inputs=hide_inputs,
            deduplicate_inputs=deduplicate_inputs,
            display_fields=show_schema,
            custom_style_function=custom_style_function,
            config=self.graph._config,
            keep_dot=keep_dot,
        )
    except ImportError as e:
        # Best effort: visualization is optional, so only warn.
        logger.warning(f"Unable to import {e}", exc_info=True)
    return None
def _process_materializers(
    self, materializers: typing.Sequence[MaterializerFactory | ExtractorFactory]
) -> tuple[list[MaterializerFactory], list[ExtractorFactory]]:
    """Splits the given factories into materializers and extractors.

    Note that this also sanitizes the variable names in the materializer dependencies,
    so one can pass in functions instead of strings.

    :param materializers: Materializers to process
    :return: Tuple of materializers and extractors
    """
    module_names = {mod.__name__ for mod in self.graph_modules}
    sanitized_materializers = []
    extractors = []
    for factory in materializers:
        # Two independent checks: anything that is neither kind is dropped silently.
        if isinstance(factory, MaterializerFactory):
            sanitized_materializers.append(factory.sanitize_dependencies(module_names))
        if isinstance(factory, ExtractorFactory):
            extractors.append(factory)
    return sanitized_materializers, extractors
[docs]
@capture_function_usage
def materialize(
    self,
    *materializers: materialization.MaterializerFactory | materialization.ExtractorFactory,
    additional_vars: list[str | Callable | Variable] = None,
    overrides: dict[str, Any] = None,
    inputs: dict[str, Any] = None,
) -> tuple[Any, dict[str, Any]]:
    """Executes and materializes with ad-hoc materializers (`to`) and extractors (`from_`).

    This does the following:

    1. Creates a new graph, appending the desired materialization nodes and prepending the
       desired extraction nodes
    2. Runs the portion of the DAG upstream of the materialization nodes outputted, as well
       as any additional nodes requested (which can be empty)
    3. Returns a Tuple[Materialization metadata, additional vars result]

    For instance, say you want to load data, process it, then materialize the output of a
    node to CSV:

    .. code-block:: python

        from hamilton import driver, base
        from hamilton.io.materialization import to

        dr = driver.Driver(my_module, {})
        # foo, bar are pd.Series
        metadata, result = dr.materialize(
            from_.csv(
                target="input_data",
                path="./input.csv",
            ),
            to.csv(
                path="./output.csv",
                id="foo_bar_csv",
                dependencies=["foo", "bar"],
                combine=base.PandasDataFrameResult(),
            ),
            additional_vars=["foo", "bar"],
        )

    The code above will do the following:

    1. Load the CSV at "./input.csv" and inject it into the DAG as input_data
    2. Run the nodes in the DAG on which "foo" and "bar" depend
    3. Materialize the dataframe with "foo" and "bar" as columns, saving it as a CSV file at
       "./output.csv". The metadata will contain any additional relevant information, and
       result will be a dictionary with the keys "foo" and "bar" containing the original data.

    Note that we pass in a `ResultBuilder` as the `combine` argument to `to`, as we may be
    materializing several nodes. This is not relevant in `from_` as we are only loading one
    dataset.

    additional_vars is used for debugging -- E.G. if you want to both realize side-effects
    and return an output for inspection. If left out, it will return an empty dictionary.

    You can bypass the `combine` keyword for `to` if only one output is required. In this
    circumstance "combining/joining" isn't required, e.g. you do that yourself in a function
    and/or the output of the function can be directly used. In the case below the output can
    be turned in to a CSV.

    .. code-block:: python

        from hamilton import driver, base
        from hamilton.io.materialization import to

        dr = driver.Driver(my_module, {})
        # foo, bar are pd.Series
        metadata, _ = dr.materialize(
            from_.csv(
                target="input_data",
                path="./input.csv",
            ),
            to.csv(
                path="./output.csv",
                id="foo_bar_csv",
                dependencies=["foo_bar_already_joined"],
            ),
        )

    This will just save it to a csv.

    Note that materializers can be any valid DataSaver -- these have an isomorphic
    relationship with the `@save_to` decorator, which means that any key utilizable in
    `save_to` can be used in a materializer. The constructor arguments for a materializer are
    the same as the arguments for `@save_to`, with an additional trick -- instead of
    requiring everything to be a `source` or `value`, you can pass in a literal, and it will
    be interpreted as a value.

    That said, if you want to parameterize your materializer based on input or some node in
    the DAG, you can easily do that as well:

    .. code-block:: python

        from hamilton import driver, base
        from hamilton.function_modifiers import source
        from hamilton.io.materialization import to

        dr = driver.Driver(my_module, {})
        # foo, bar are pd.Series
        metadata, result = dr.materialize(
            from_.csv(
                target="input_data",
                path=source("load_path"),
            ),
            to.csv(
                path=source("save_path"),
                id="foo_bar_csv",
                dependencies=["foo", "bar"],
                combine=base.PandasDataFrameResult(),
            ),
            additional_vars=["foo", "bar"],
            inputs={"save_path": "./output.csv"},
        )

    While this is a contrived example, you could imagine something more powerful. Say, for
    instance, you have created and registered a custom `model_registry` materializer that
    applies to an argument of your model class, and requires `training_data` to infer the
    signature. You could call it like this:

    .. code-block:: python

        from hamilton import driver, base
        from hamilton.function_modifiers import source
        from hamilton.io.materialization import to

        dr = driver.Driver(my_module, {})
        metadata, _ = dr.materialize(
            to.model_registry(
                training_data=source("training_data"),
                id="foo_model_registry",
                tags={"run_id": ..., "training_date": ..., ...},
                dependencies=["foo_model"],
            ),
        )

    In this case, we bypass a result builder (as there's only one model), the single node we
    depend on gets saved, and we pass in the training data as an input so the materializer
    can infer the signature.

    You could also imagine a driver that loads up a model, runs inference, then saves the
    result:

    .. code-block:: python

        from hamilton import driver, base
        from hamilton.function_modifiers import source
        from hamilton.io.materialization import to

        dr = driver.Driver(my_module, {})
        metadata, _ = dr.materialize(
            from_.model_registry(
                target="input_model",
                query_tags={
                    "training_date": ...,
                    "model_version": ...,
                },  # query based on run_id, model_version
            ),
            to.csv(
                path=source("save_path"),
                id="save_inference_data",
                dependencies=["inference_data"],
            ),
        )

    Note that the "from" extractor has an interesting property -- it effectively functions as
    overrides. This means that it can *replace* nodes within a DAG, short-circuiting their
    behavior. Similar to passing overrides, but they are dynamically computed with the DAG,
    rather than statically included from the beginning.

    This is customizable through a few APIs:

    1. Custom data savers ( :doc:`/concepts/function-modifiers`)
    2. Custom result builders
    3. Custom data loaders ( :doc:`/concepts/function-modifiers`)

    If you find yourself writing these, please consider contributing back! We would love
    to round out the set of available materialization tools.

    :param materializers: Materializer/extractors to use, created with to.xyz or `from.xyz`
    :param additional_vars: Additional variables to return from the graph
    :param overrides: Overrides to pass to execution
    :param inputs: Inputs to pass to execution
    :return: Tuple[Materialization metadata|data, additional_vars result]
    :raise ValueError: if neither a materializer nor `additional_vars` is requested.
    """
    if additional_vars is None:
        additional_vars = []
    run_successful = True
    error_execution = None
    run_id = str(uuid.uuid4())
    outputs = (None, None)
    final_vars = self._create_final_vars(additional_vars)
    # Pre-bind everything the `finally` block may touch. Without the `function_graph`
    # default, a failure before the modified graph is built (e.g. the "No output
    # requested" error below) would surface as an UnboundLocalError raised from the
    # `post_graph_execute` hook call, hiding the real error.
    materializer_vars = []
    function_graph = self.graph
    try:
        materializer_factories, extractor_factories = self._process_materializers(materializers)
        if not materializer_factories and not final_vars:
            raise ValueError(
                "No output requested. Please either pass in materializers that will save data, or pass in `additional_vars` to compute."
            )
        function_graph = materialization.modify_graph(
            self.graph, materializer_factories, extractor_factories
        )
        Driver._perform_graph_validations(self.adapter, function_graph, self.graph_modules)
        if self.adapter.does_hook("post_graph_construct", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "post_graph_construct",
                graph=function_graph,
                modules=self.graph_modules,
                config=function_graph.config,
            )
        # need to validate the right inputs has been provided.
        # we do this on the modified graph.
        # Note we will not run the loaders if they're not upstream of the
        # materializers or additional_vars
        materializer_vars = [m.id for m in materializer_factories]
        if self.adapter.does_hook("pre_graph_execute", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "pre_graph_execute",
                run_id=run_id,
                graph=function_graph,
                final_vars=final_vars + materializer_vars,
                inputs=inputs,
                overrides=overrides,
            )
        nodes, user_nodes = function_graph.get_upstream_nodes(
            final_vars + materializer_vars, inputs, overrides
        )
        Driver.validate_inputs(function_graph, self.adapter, user_nodes, inputs, nodes)
        all_nodes = nodes | user_nodes
        self.graph_executor.validate(list(all_nodes))
        raw_results = self.__raw_execute(
            final_vars=final_vars + materializer_vars,
            inputs=inputs,
            overrides=overrides,
            _fn_graph=function_graph,
            _run_id=run_id,
        )
        # Split the raw results into what the materializers produced and what was
        # explicitly requested through `additional_vars`.
        materialization_output = {key: raw_results[key] for key in materializer_vars}
        raw_results_output = {key: raw_results[key] for key in final_vars}
        outputs = materialization_output, raw_results_output
    except Exception as e:
        run_successful = False
        logger.error(SLACK_ERROR_MESSAGE)
        error_execution = e
        raise
    finally:
        # Always notify the adapters, on success as well as on failure.
        if self.adapter.does_hook("post_graph_execute", is_async=False):
            self.adapter.call_all_lifecycle_hooks_sync(
                "post_graph_execute",
                run_id=run_id,
                graph=function_graph,
                success=run_successful,
                error=error_execution,
                results=outputs[1],
            )
    return outputs
[docs]
@capture_function_usage
def visualize_materialization(
    self,
    *materializers: MaterializerFactory | ExtractorFactory,
    output_file_path: str = None,
    render_kwargs: dict = None,
    additional_vars: list[str | Callable | Variable] = None,
    inputs: dict[str, Any] = None,
    graphviz_kwargs: dict = None,
    overrides: dict[str, Any] = None,
    show_legend: bool = True,
    orient: str = "LR",
    hide_inputs: bool = False,
    deduplicate_inputs: bool = False,
    show_schema: bool = True,
    custom_style_function: Callable = None,
    bypass_validation: bool = False,
    keep_dot: bool = False,
) -> Optional["graphviz.Digraph"]:  # noqa F821
    """Draws the graph as it would be executed by a materialization call.

    This helps give you a sense of how materialization will impact the DAG.

    :param materializers: Materializers/Extractors to use, see the materialize() function
    :param additional_vars: Additional variables to compute (in addition to materializers)
    :param output_file_path: Path to output file. Optional. Skip if in a Jupyter Notebook.
    :param render_kwargs: Arguments to pass to render. Optional.
    :param inputs: Inputs to pass to execution. Optional.
    :param graphviz_kwargs: Arguments to pass to graphviz. Optional.
    :param overrides: Overrides to pass to execution. Optional.
    :param show_legend: If True, add a legend to the visualization based on the DAG's nodes.
    :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
        `orient` will be overridden by the value of `graphviz_kwargs['graph_attr']['rankdir']`
        see (https://graphviz.org/docs/attr-types/rankdir/)
    :param hide_inputs: If True, no input nodes are displayed.
    :param deduplicate_inputs: If True, remove duplicate input nodes.
        Can improve readability depending on the specifics of the DAG.
    :param show_schema: If True, show the schema of the materialized nodes
        if nodes have schema metadata attached.
    :param custom_style_function: Optional. Custom style function.
    :param bypass_validation: If True, bypass validation. Optional.
    :param keep_dot: If true, produce a DOT file (ref: https://graphviz.org/doc/info/lang.html)
    :return: The graphviz graph, if you want to do something with it
    """
    extra_vars = [] if additional_vars is None else additional_vars
    materializer_factories, extractor_factories = self._process_materializers(materializers)
    # Visualize the graph with the materializer/extractor nodes spliced in.
    modified_graph = materialization.modify_graph(
        self.graph, materializer_factories, extractor_factories
    )
    requested_vars = self._create_final_vars(extra_vars)
    requested_vars = requested_vars + [factory.id for factory in materializer_factories]
    return Driver._visualize_execution_helper(
        modified_graph,
        self.adapter,
        requested_vars,
        output_file_path,
        render_kwargs,
        inputs,
        graphviz_kwargs,
        overrides,
        show_legend=show_legend,
        orient=orient,
        hide_inputs=hide_inputs,
        deduplicate_inputs=deduplicate_inputs,
        show_schema=show_schema,
        custom_style_function=custom_style_function,
        bypass_validation=bypass_validation,
        keep_dot=keep_dot,
    )
[docs]
def validate_execution(
    self,
    final_vars: list[str | Callable | Variable],
    overrides: dict[str, Any] = None,
    inputs: dict[str, Any] = None,
):
    """Validates execution of the graph.

    One can call this to validate execution, independently of actually executing.
    Note this has no return -- it will raise a ValueError if there is an issue.

    :param final_vars: Final variables to compute
    :param overrides: Overrides to pass to execution.
    :param inputs: Inputs to pass to execution.
    :raise ValueError: if any issues with execution can be detected.
    """
    # Normalise the requested outputs first, the same way `has_cycles` and
    # `validate_materialization` do, so that the callables/variables allowed by the
    # signature are handled and not only plain node names.
    _final_vars = self._create_final_vars(final_vars)
    nodes, user_nodes = self.graph.get_upstream_nodes(_final_vars, inputs, overrides)
    Driver.validate_inputs(self.graph, self.adapter, user_nodes, inputs, nodes)
    self.graph_executor.validate(list(nodes | user_nodes))
[docs]
def validate_materialization(
    self,
    *materializers: materialization.MaterializerFactory | materialization.ExtractorFactory,
    additional_vars: list[str | Callable | Variable] = None,
    overrides: dict[str, Any] = None,
    inputs: dict[str, Any] = None,
):
    """Validates materialization of the graph. Effectively .materialize() with a dry-run.

    Note this has no return -- it will raise a ValueError if there is an issue.

    :param materializers: Materializers/extractors to use, see the materialize() function
    :param additional_vars: Additional variables to compute (in addition to materializers)
    :param overrides: Overrides to pass to execution. Optional.
    :param inputs: Inputs to pass to execution. Optional.
    :raise ValueError: if any issues with materialization can be detected.
    """
    if additional_vars is None:
        additional_vars = []
    final_vars = self._create_final_vars(additional_vars)
    # `_process_materializers` already sanitizes the materializer dependencies, so no
    # second sanitization pass is needed here.
    materializer_factories, extractor_factories = self._process_materializers(materializers)
    materializer_vars = [m.id for m in materializer_factories]
    function_graph = materialization.modify_graph(
        self.graph, materializer_factories, extractor_factories
    )
    # need to validate the right inputs has been provided.
    # we do this on the modified graph.
    nodes, user_nodes = function_graph.get_upstream_nodes(
        final_vars + materializer_vars, inputs, overrides
    )
    Driver.validate_inputs(function_graph, self.adapter, user_nodes, inputs, nodes)
    all_nodes = nodes | user_nodes
    self.graph_executor.validate(list(all_nodes))
@property
def cache(self) -> HamiltonCacheAdapter:
    """Directly access the cache adapter.

    :return: the first ``HamiltonCacheAdapter`` found among ``self.adapter.adapters``.
    :raise KeyError: if no cache adapter is available, i.e. there is no adapter set at all
        or none of the registered adapters is a ``HamiltonCacheAdapter``.
    """
    if self.adapter:
        for adapter in self.adapter.adapters:
            if isinstance(adapter, HamiltonCacheAdapter):
                return adapter
    # Every "not found" case ends up here, so the property never silently returns None.
    raise KeyError(
        "Cache not yet set. Add a cache by using ``Builder().with_cache()`` when building the ``Driver``."
    )
[docs]
class Builder:
[docs]
def __init__(self):
    """Creates an empty builder.

    Takes no parameters: every field starts at its default and is filled in through the
    builder's methods.
    """
    # --- version toggle: switched on by `enable_dynamic_execution` ---
    self.v2_executor = False

    # --- fields shared by both execution modes ---
    self.modules = []
    self.config = {}
    self.materializers = []
    self.legacy_graph_adapter = None
    # Whether later modules may override nodes of the same name
    self._allow_module_overrides = False

    # --- standard execution ---
    self.adapters: list[lifecycle_base.LifecycleAdapter] = []

    # --- dynamic execution ---
    self.grouping_strategy = None
    self.execution_manager = None
    self.local_executor = None
    self.remote_executor = None
def _require_v2(self, message: str):
    """Raises ``ValueError(message)`` unless ``self.v2_executor`` is truthy.

    :param message: error message to raise with.
    :raise ValueError: if ``self.v2_executor`` is falsy.
    """
    if self.v2_executor:
        return
    raise ValueError(message)
def _require_field_unset(self, field: str, message: str, unset_value: Any = None):
    """Raises ``ValueError(message)`` if the attribute ``field`` differs from ``unset_value``.

    :param field: name of the attribute on ``self`` to inspect.
    :param message: error message to raise with.
    :param unset_value: value that counts as "not set". Defaults to ``None``.
    :raise ValueError: if the attribute is not equal to ``unset_value``.
    """
    current_value = getattr(self, field)
    already_set = current_value != unset_value
    if already_set:
        raise ValueError(message)
def _require_field_set(self, field: str, message: str, unset_value: Any = None):
    """Raises ``ValueError(message)`` if the attribute ``field`` equals ``unset_value``.

    :param field: name of the attribute on ``self`` to inspect.
    :param message: error message to raise with.
    :param unset_value: value that counts as "not set". Defaults to ``None``.
    :raise ValueError: if the attribute is equal to ``unset_value``.
    """
    current_value = getattr(self, field)
    still_unset = current_value == unset_value
    if still_unset:
        raise ValueError(message)
[docs]
def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> Self:
"""Enables the Parallelizable[] type, which in turn enables:
1. Grouped execution into tasks
2. Parallel execution
:return: self
"""
if not allow_experimental_mode:
raise ValueError(
"Remote execution is currently experimental. "
"Please set allow_experiemental_mode=True to enable it."
)
self.v2_executor = True
return self
[docs]
def with_config(self, config: dict[str, Any]) -> Self:
"""Adds the specified configuration to the config.
This can be called multilple times -- later calls will take precedence.
:param config: Config to use.
:return: self
"""
self.config.update(config)
return self
[docs]
def with_modules(self, *modules: ModuleType) -> Self:
"""Adds the specified modules to the modules list.
This can be called multiple times.
:param modules: Modules to use.
:return: self
"""
self.modules.extend(modules)
return self
[docs]
def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> Self:
    """Sets the (legacy) graph adapter to use. Can only be done once.

    :param adapter: Adapter to use.
    :return: self
    :raise ValueError: if an adapter was already set through this method.
    """
    already_set_message = "Cannot set adapter twice."
    self._require_field_unset("legacy_graph_adapter", already_set_message)
    self.legacy_graph_adapter = adapter
    return self
[docs]
def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> Self:
    """Adds the given lifecycle adapters to the builder.

    This extends the list of adapters, so it can be called multiple times.

    :param adapters: Adapters to use.
    :return: self
    :raise ValueError: if a ``HamiltonCacheAdapter`` is passed while the ``cache`` field
        is already set.
    """
    for candidate in adapters:
        if isinstance(candidate, HamiltonCacheAdapter):
            # NOTE(review): `cache` is not initialised in the `__init__` visible in this
            # file section -- confirm it is defined elsewhere on the class.
            self._require_field_unset(
                "cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`."
            )
            # A single check is enough, however many cache adapters were passed.
            break
    self.adapters.extend(adapters)
    return self
[docs]
def with_data_quality_disabled(self) -> Self:
    """Turns off all ``@check_output`` / ``@check_output_custom`` validators at
    graph-construction time. No validator nodes are created, so there is zero runtime cost.

    Shorthand for ``.with_config({DISABLE_DATA_QUALITY_CHECKS_CONFIG_KEY: True})``.

    Because this goes through ``with_config``, the last write wins: a later
    ``.with_config({DISABLE_DATA_QUALITY_CHECKS_CONFIG_KEY: False})`` re-enables validation.

    :return: self
    """
    disable_flag = {DISABLE_DATA_QUALITY_CHECKS_CONFIG_KEY: True}
    return self.with_config(disable_flag)
[docs]
def with_materializers(self, *materializers: ExtractorFactory | MaterializerFactory) -> Self:
    """Add materializer nodes to the `Driver`

    The generated nodes can be referenced by name in `.execute()`

    :param materializers: materializers to add to the dataflow
    :return: self
    :raises ValueError: If any argument is not an ``ExtractorFactory`` or
        ``MaterializerFactory``. A dedicated message is used when a single sequence
        was passed instead of unpacked arguments.
    """
    accepted_types = (ExtractorFactory, MaterializerFactory)
    # Check each argument's type directly. Testing the truthiness of the offending
    # argument would let falsy invalid values (None, [], 0, "") slip through.
    if not all(isinstance(m, accepted_types) for m in materializers):
        if len(materializers) == 1 and isinstance(materializers[0], Sequence):
            raise ValueError(
                "`.with_materializers()` received a sequence. Unpack it by prepending `*` e.g., `*[to.json(...), from_.parquet(...)]`"
            )
        raise ValueError(
            f"`.with_materializers()` only accepts materializers. Received instead: {materializers}"
        )
    self.materializers.extend(materializers)
    return self
[docs]
def with_cache(
    self,
    path: str | pathlib.Path = ".hamilton_cache",
    metadata_store: MetadataStore | None = None,
    result_store: ResultStore | None = None,
    default: Literal[True] | Sequence[str] | None = None,
    recompute: Literal[True] | Sequence[str] | None = None,
    ignore: Literal[True] | Sequence[str] | None = None,
    disable: Literal[True] | Sequence[str] | None = None,
    default_behavior: Literal["default", "recompute", "disable", "ignore"] = "default",
    default_loader_behavior: Literal["default", "recompute", "disable", "ignore"] = "default",
    default_saver_behavior: Literal["default", "recompute", "disable", "ignore"] = "default",
    log_to_file: bool = False,
) -> Self:
    """Create a ``HamiltonCacheAdapter`` from the arguments and add it to the builder's adapters.

    Only one cache adapter may be registered, whether through this method or
    through `.with_adapters()`.

    :param path: path where the cache metadata and results will be stored
    :param metadata_store: BaseStore handling metadata for the cache adapter
    :param result_store: BaseStore caching dataflow execution results
    :param default: Set caching behavior to DEFAULT for specified node names. If True, apply to all nodes.
    :param recompute: Set caching behavior to RECOMPUTE for specified node names. If True, apply to all nodes.
    :param ignore: Set caching behavior to IGNORE for specified node names. If True, apply to all nodes.
    :param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes.
    :param default_behavior: Set the default caching behavior.
    :param default_loader_behavior: Set the default caching behavior for `DataLoader` nodes.
    :param default_saver_behavior: Set the default caching behavior for `DataSaver` nodes.
    :param log_to_file: If True, the cache adapter logs will be stored in JSONL format under the metadata_store directory
    :return: self
    :raises ValueError: If a cache adapter is already registered on this builder.

    Learn more on the :doc:`/concepts/caching` Concepts page.

    .. code-block:: python

        from hamilton import driver
        import my_dataflow

        dr = (
            driver.Builder()
            .with_modules(my_dataflow)
            .with_cache()
            .build()
        )

        # execute twice
        dr.execute([...])
        dr.execute([...])

        # view cache logs
        dr.cache.logs()
    """
    self._require_field_unset(
        "cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`."
    )

    # Forward every argument unchanged to the cache adapter constructor.
    adapter_kwargs = {
        "path": path,
        "metadata_store": metadata_store,
        "result_store": result_store,
        "default": default,
        "recompute": recompute,
        "ignore": ignore,
        "disable": disable,
        "default_behavior": default_behavior,
        "default_loader_behavior": default_loader_behavior,
        "default_saver_behavior": default_saver_behavior,
        "log_to_file": log_to_file,
    }
    cache_adapter = HamiltonCacheAdapter(**adapter_kwargs)
    self.adapters.append(cache_adapter)
    return self
@property
def cache(self) -> HamiltonCacheAdapter | None:
    """First ``HamiltonCacheAdapter`` found in ``self.adapters``, or ``None`` if there is none.

    A cache adapter can arrive via `.with_cache()` or `.with_adapters(...)`.
    Exposed as an attribute so that ``._require_field_unset("cache", ...)`` can check it.
    """
    registered = self.adapters or ()
    return next(
        (entry for entry in registered if isinstance(entry, HamiltonCacheAdapter)),
        None,
    )
[docs]
def with_execution_manager(self, execution_manager: executors.ExecutionManager) -> Self:
    """Sets the execution manager to use. Note that this cannot be used if local_executor
    or remote_executor are also set.

    :param execution_manager: Execution manager to use.
    :return: self
    :raises ValueError: If the V2 driver is not enabled, if an execution manager was
        already set, or if a local or remote executor was already set.
    """
    self._require_v2("Cannot set execution manager without first enabling the V2 Driver")
    self._require_field_unset("execution_manager", "Cannot set execution manager twice")
    self._require_field_unset(
        "remote_executor",
        "Cannot set execution manager with remote executor set -- these are disjoint",
    )
    # `build()` ignores local_executor whenever an execution manager is present, and
    # `with_local_executor` already rejects the reverse order. Reject this order too
    # instead of silently dropping the user's local executor.
    self._require_field_unset(
        "local_executor",
        "Cannot set execution manager with local executor set -- these are disjoint",
    )
    self.execution_manager = execution_manager
    return self
[docs]
def with_remote_executor(self, remote_executor: executors.TaskExecutor) -> Self:
    """Sets the remote executor to use. Note that this cannot be used if an
    execution manager is also set.

    :param remote_executor: Remote executor to use
    :return: self
    :raises ValueError: If the V2 driver is not enabled, if a remote executor was
        already set, or if an execution manager was already set.
    """
    # The message names the remote executor (it previously said "execution manager").
    self._require_v2("Cannot set remote executor without first enabling the V2 Driver")
    self._require_field_unset("remote_executor", "Cannot set remote executor twice")
    self._require_field_unset(
        "execution_manager",
        "Cannot set remote executor with execution manager set -- these are disjoint",
    )
    self.remote_executor = remote_executor
    return self
[docs]
def with_local_executor(self, local_executor: executors.TaskExecutor) -> Self:
    """Sets the local executor to use. Note that this cannot be used if an
    execution manager is also set.

    :param local_executor: Local executor to use
    :return: self
    :raises ValueError: If the V2 driver is not enabled, if a local executor was
        already set, or if an execution manager was already set.
    """
    # The message names the local executor (it previously said "execution manager").
    self._require_v2("Cannot set local executor without first enabling the V2 Driver")
    self._require_field_unset("local_executor", "Cannot set local executor twice")
    self._require_field_unset(
        "execution_manager",
        "Cannot set local executor with execution manager set -- these are disjoint",
    )
    self.local_executor = local_executor
    return self
[docs]
def with_grouping_strategy(self, grouping_strategy: grouping.GroupingStrategy) -> Self:
    """Sets the grouping strategy, which tells the driver how to group nodes into tasks
    for execution. Requires the V2 driver and may only be set once.

    :param grouping_strategy: Grouping strategy to use.
    :return: self
    :raises ValueError: If the V2 driver is not enabled or a grouping strategy was already set.
    """
    needs_v2_msg = "Cannot set grouping strategy without first enabling the V2 Driver"
    duplicate_msg = "Cannot set grouping strategy twice"
    self._require_v2(needs_v2_msg)
    self._require_field_unset("grouping_strategy", duplicate_msg)
    self.grouping_strategy = grouping_strategy
    return self
[docs]
def allow_module_overrides(self) -> Self:
"""Same named functions in different modules get overwritten.
If multiple modules have same named functions, the later module overrides the previous one(s).
The order of listing the modules is important, since later ones will overwrite the previous ones. This is a global call affecting all imported modules.
See https://github.com/apache/hamilton/tree/main/examples/module_overrides for more info.
:return: self
"""
self._allow_module_overrides = True
return self
[docs]
def build(self) -> Driver:
    """Builds the driver -- note that this can return a different class, so you'll likely
    want to have a sense of what it returns.

    When the V2 driver is enabled, a ``TaskBasedGraphExecutor`` is created. It uses the
    configured execution manager or, if none was set, a ``DefaultExecutionManager`` built
    from the local/remote executors (defaulting to ``SynchronousLocalTaskExecutor`` and
    ``MultiThreadingExecutor(max_tasks=10)``). The grouping strategy defaults to
    ``GroupByRepeatableBlocks``.

    NOTE(review): earlier documentation says this "defaults to a dictionary adapter if no
    adapter is set"; that fallback is not visible in this method -- confirm in ``Driver``.

    :return: The driver you specified.
    """
    # Work on a copy: appending the legacy adapter to `self.adapters` itself would
    # mutate the builder, so every further `build()` call would add it again.
    adapters = list(self.adapters) if self.adapters is not None else []
    if self.legacy_graph_adapter is not None:
        adapters.append(self.legacy_graph_adapter)

    graph_executor = None
    if self.v2_executor:
        execution_manager = self.execution_manager
        if execution_manager is None:
            local_executor = self.local_executor or executors.SynchronousLocalTaskExecutor()
            remote_executor = self.remote_executor or executors.MultiThreadingExecutor(
                max_tasks=10
            )
            execution_manager = executors.DefaultExecutionManager(
                local_executor=local_executor, remote_executor=remote_executor
            )
        grouping_strategy = self.grouping_strategy or grouping.GroupByRepeatableBlocks()
        graph_executor = TaskBasedGraphExecutor(
            execution_manager=execution_manager,
            grouping_strategy=grouping_strategy,
            adapter=lifecycle_base.LifecycleAdapterSet(*adapters),
        )

    return Driver(
        self.config,
        *self.modules,
        adapter=adapters,
        _materializers=self.materializers,
        _graph_executor=graph_executor,
        _use_legacy_adapter=False,
        allow_module_overrides=self._allow_module_overrides,
    )
[docs]
def copy(self) -> "Builder":
"""Creates a copy of the current state of this Builder.
NOTE. The copied Builder currently holds reference of Builder attributes
"""
new_builder = Builder()
new_builder.v2_executor = self.v2_executor
new_builder.config = self.config.copy()
new_builder.modules = self.modules.copy()
new_builder.legacy_graph_adapter = self.legacy_graph_adapter
new_builder.adapters = self.adapters.copy()
new_builder.materializers = self.materializers.copy()
new_builder.execution_manager = self.execution_manager
new_builder.local_executor = self.local_executor
new_builder.remote_executor = self.remote_executor
new_builder.grouping_strategy = self.grouping_strategy
return new_builder
if __name__ == "__main__":
    # Ad-hoc example: import the module named by the first CLI argument, build a Driver
    # over it with a sample config, execute two outputs and print the result.
    import importlib

    # Send this module's log records to stdout at INFO level.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(
        logging.Formatter("[%(levelname)s] %(asctime)s %(name)s(%(lineno)s): %(message)s")
    )
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)

    if len(sys.argv) < 2:
        logger.error("No modules passed")
        sys.exit(1)
    logger.info(f"Importing {sys.argv[1]}")
    module = importlib.import_module(sys.argv[1])

    # NOTE(review): `x` is not used anywhere below -- confirm whether it is still needed.
    x = pd.date_range("2019-01-05", "2020-12-31", freq="7D")
    x.index = x

    day_format = "%Y-%m-%d"
    driver_config = {
        "VERSION": "kids",
        "as_of": datetime.strptime("2019-06-01", day_format),
        "end_date": "2020-12-31",
        "start_date": "2019-01-05",
        "start_date_d": datetime.strptime("2019-01-05", day_format),
        "end_date_d": datetime.strptime("2020-12-31", day_format),
        "segment_filters": {"business_line": "womens"},
    }
    dr = Driver(driver_config, module)

    requested_outputs = ["date_index", "some_column"]
    df = dr.execute(
        requested_outputs,
        # ,overrides={'DATE': pd.Series(0)}
        display_graph=False,
    )
    print(df)