# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import abc
from abc import ABC
from collections.abc import Collection
from types import ModuleType
from typing import TYPE_CHECKING, Any, final
from hamilton import graph_types, node
# The imports guarded by TYPE_CHECKING below exist only for type hints.
# Python resolves imports at runtime, so a module needed purely for typing can still cause a
# circular import -- the dependencies between types don't always form a perfect DAG.
# In this case, we're breaking the following loop:
# -> lifecycle_api depends on graph_types and FunctionGraph
# -> graph_types depends on hamilton.base
# -> hamilton.base depends on lifecycle_api, as some interfaces for graph adapters live there
# To really fix this we should move everything user-facing out of base (which is a pretty sloppy
# name for a package anyway) and put it where it belongs. For now we're OK with the TYPE_CHECKING hack.
if TYPE_CHECKING:
    # Seen only by static type checkers; these imports are skipped at runtime.
    from hamilton.execution.grouping import NodeGroupPurpose
    from hamilton.graph import FunctionGraph
else:
    # Runtime placeholder so that unquoted ``NodeGroupPurpose`` annotations in this module
    # still evaluate. FunctionGraph gets no placeholder; the annotations that mention it
    # here are quoted strings.
    NodeGroupPurpose = None
from hamilton.graph_types import HamiltonGraph, HamiltonNode
from hamilton.lifecycle.base import (
BaseDoBuildResult,
BaseDoCheckEdgeTypesMatch,
BaseDoNodeExecute,
BaseDoValidateInput,
BasePostGraphConstruct,
BasePostGraphExecute,
BasePostNodeExecute,
BasePostTaskExecute,
BasePostTaskExpand,
BasePostTaskGroup,
BasePostTaskReturn,
BasePreGraphExecute,
BasePreNodeExecute,
BasePreTaskExecute,
BasePreTaskSubmission,
BaseValidateGraph,
BaseValidateNode,
)
try:
    from typing import override
except ImportError:

    def override(func):
        """No-op stand-in used when ``typing.override`` cannot be imported.

        :param func: the method being marked as an override.
        :return: ``func`` itself, unchanged.
        """
        return func
class ResultBuilder(BaseDoBuildResult, abc.ABC):
    """Abstract base for objects that assemble the final result of a graph execution.

    Subclasses implement :meth:`build_result`. The hook-facing :meth:`do_build_result` is final
    and forwards to it. Overriding :meth:`output_type` is optional but recommended, as it allows
    type-checking; it defaults to ``Any``, which means the builder will connect to anything.
    """

    @abc.abstractmethod
    def build_result(self, **outputs: Any) -> Any:
        """Build the result from a set of outputs.

        :param outputs: the outputs from the execution of the graph, passed as keyword arguments.
        :return: the result of the execution of the graph.
        """

    @override
    @final
    def do_build_result(self, outputs: dict[str, Any]) -> Any:
        """Implement ``do_build_result`` from ``BaseDoBuildResult`` by delegating to ``build_result``.

        Users implement ``build_result`` rather than this method, which leaves room to change
        the API/implementation of the internal set of hooks.

        :param outputs: mapping of output name to value; unpacked into keyword arguments.
        :return: whatever ``build_result`` returns.
        """
        built = self.build_result(**outputs)
        return built

    def output_type(self) -> type:
        """Return the output type of this result builder.

        :return: the type that this creates; ``Any`` unless overridden.
        """
        return Any
class LegacyResultMixin(ResultBuilder, ABC):
    """Backwards-compatible legacy result builder.

    This keeps ``build_result`` as a static method because that is how it used to be declared,
    even though that often caused confusion. If you want a result builder, use
    ``ResultBuilder`` instead.
    """

    @staticmethod
    def build_result(**outputs: Any) -> Any:
        """Build the result from a set of outputs.

        NOTE(review): this method is not marked abstract here, so unless a subclass overrides
        it the result is ``None``.

        :param outputs: the outputs from the execution of the graph.
        :return: the result of the execution of the graph.
        """
        return None
class GraphAdapter(
    BaseDoNodeExecute,
    LegacyResultMixin,
    BaseDoValidateInput,
    BaseDoCheckEdgeTypesMatch,
    abc.ABC,
):
    """Implementation of HamiltonGraphAdapter expressed through lifecycle methods/hooks.

    Subclasses provide ``check_node_type_equivalence`` and ``execute_node``; the final
    ``do_*`` methods below bridge the lifecycle hook interface onto those two methods.
    """

    @staticmethod
    @abc.abstractmethod
    def check_node_type_equivalence(node_type: type, input_type: type) -> bool:
        """Check whether two types are equivalent.

        Static purely for legacy reasons. This is used while the function graph is being
        created and the annotations are statically type-checked for compatibility.

        :param node_type: The type of the node.
        :param input_type: The type of the input that would flow into the node.
        :return: True if the types are equivalent, False otherwise.
        """

    @override
    @final
    def do_node_execute(
        self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None
    ) -> Any:
        """Execute a node by delegating to ``execute_node``.

        :param run_id: ID of the current run; accepted but not forwarded.
        :param node_: the node to execute.
        :param kwargs: keyword arguments for the node.
        :param task_id: ID of the enclosing task, if any; accepted but not forwarded.
        :return: whatever ``execute_node`` returns.
        """
        outcome = self.execute_node(node_, kwargs)
        return outcome

    @override
    @final
    def do_check_edge_types_match(self, type_from: type, type_to: type) -> bool:
        """Check an edge by delegating to ``check_node_type_equivalence``.

        :param type_from: type at the source of the edge.
        :param type_to: type at the destination of the edge.
        :return: the result of ``check_node_type_equivalence``.
        """
        # Argument order flips on purpose: the destination is the "node type" and the
        # source is the "input type" in the legacy signature.
        matches = self.check_node_type_equivalence(type_to, type_from)
        return matches

    @abc.abstractmethod
    def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any:
        """Execute a node that represents a hamilton function.

        Note, in some adapters this might just return some type of "future".

        :param node: the Hamilton Node
        :param kwargs: the kwargs required to exercise the node function.
        :return: the result of exercising the node.
        """
class NodeExecutionHook(BasePreNodeExecute, BasePostNodeExecute, abc.ABC):
    """Implement this to hook into the node execution lifecycle.

    Provide ``run_before_node_execution`` and ``run_after_node_execution``. The final
    ``pre_node_execute`` / ``post_node_execute`` methods unpack the node into plain values
    and call them.
    """

    @abc.abstractmethod
    def run_before_node_execution(
        self,
        *,
        node_name: str,
        node_tags: dict[str, Any],
        node_kwargs: dict[str, Any],
        node_return_type: type,
        task_id: str | None,
        run_id: str,
        node_input_types: dict[str, Any],
        **future_kwargs: Any,
    ):
        """Hook that is executed prior to node execution.

        :param node_name: Name of the node.
        :param node_tags: Tags of the node.
        :param node_kwargs: Keyword arguments to pass to the node.
        :param node_return_type: Return type of the node.
        :param task_id: The ID of the task, None if not in a task-based environment.
        :param run_id: Run ID (unique in process scope) of the current run. Use this to track state.
        :param node_input_types: The input types to the node and what it is expecting.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """

    @override
    @final
    def pre_node_execute(
        self,
        *,
        run_id: str,
        node_: node.Node,
        kwargs: dict[str, Any],
        task_id: str | None = None,
    ):
        """Bridge from the internal hook to ``run_before_node_execution``. Do not override this!

        :param run_id: ID of the current run.
        :param node_: node about to be executed.
        :param kwargs: keyword arguments that will be passed to the node.
        :param task_id: ID of the enclosing task, if any.
        """
        hook_arguments = dict(
            node_name=node_.name,
            node_tags=node_.tags,
            node_kwargs=kwargs,
            node_return_type=node_.type,
            task_id=task_id,
            run_id=run_id,
            # Only the first element of each ``input_types`` entry is exposed to the user.
            node_input_types={
                input_name: entry[0] for input_name, entry in node_.input_types.items()
            },
        )
        self.run_before_node_execution(**hook_arguments)

    @abc.abstractmethod
    def run_after_node_execution(
        self,
        *,
        node_name: str,
        node_tags: dict[str, Any],
        node_kwargs: dict[str, Any],
        node_return_type: type,
        result: Any,
        error: Exception | None,
        success: bool,
        task_id: str | None,
        run_id: str,
        **future_kwargs: Any,
    ):
        """Hook that is executed post node execution.

        :param node_name: Name of the node in question.
        :param node_tags: Tags of the node.
        :param node_kwargs: Keyword arguments passed to the node.
        :param node_return_type: Return type of the node.
        :param result: Output of the node, None if an error occurred.
        :param error: Error that occurred, None if no error occurred.
        :param success: Whether the node executed successfully.
        :param task_id: The ID of the task, None if not in a task-based environment.
        :param run_id: Run ID (unique in process scope) of the current run. Use this to track state.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """

    @override
    @final
    def post_node_execute(
        self,
        *,
        run_id: str,
        node_: node.Node,
        kwargs: dict[str, Any],
        success: bool,
        error: Exception | None,
        result: Any | None,
        task_id: str | None = None,
    ):
        """Bridge from the internal hook to ``run_after_node_execution``. Do not override this!

        :param run_id: ID of the current run.
        :param node_: node that was executed.
        :param kwargs: keyword arguments that were passed to the node.
        :param success: whether the node executed successfully.
        :param error: the error that occurred, if any.
        :param result: the node's output, if any.
        :param task_id: ID of the enclosing task, if any.
        """
        hook_arguments = dict(
            node_name=node_.name,
            node_tags=node_.tags,
            node_kwargs=kwargs,
            node_return_type=node_.type,
            result=result,
            error=error,
            task_id=task_id,
            success=success,
            run_id=run_id,
        )
        self.run_after_node_execution(**hook_arguments)
class GraphExecutionHook(BasePreGraphExecute, BasePostGraphExecute):
    """Implement this to execute code before and after graph execution (useful for logging, etc.).

    Provide ``run_before_graph_execution`` and ``run_after_graph_execution``; the final
    ``pre_graph_execute`` / ``post_graph_execute`` methods convert the internal graph into a
    ``HamiltonGraph`` and call them.
    """

    @override
    @final
    def pre_graph_execute(
        self,
        *,
        run_id: str,
        graph: "FunctionGraph",
        final_vars: list[str],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
    ):
        """Implementation of the pre_graph_execute hook.

        Converts the inputs to the format the user-facing hook expects, including the names
        of the nodes to execute, then delegates to ``run_before_graph_execution``.

        :param run_id: ID of the current run.
        :param graph: the function graph being executed.
        :param final_vars: requested output variables.
        :param inputs: inputs passed to the graph.
        :param overrides: overrides passed to the graph.
        :return: whatever ``run_before_graph_execution`` returns.
        """
        upstream, user_supplied = graph.get_upstream_nodes(final_vars, inputs, overrides)
        # Nodes to run = everything upstream minus what the second return value lists.
        # The result is a set, so the order of names in execution_path is unspecified.
        to_run = set(upstream) - set(user_supplied)
        hamilton_graph = HamiltonGraph.from_graph(graph)
        node_names = [candidate.name for candidate in to_run]
        return self.run_before_graph_execution(
            graph=hamilton_graph,
            final_vars=final_vars,
            inputs=inputs,
            overrides=overrides,
            execution_path=node_names,
            run_id=run_id,
        )

    @override
    @final
    def post_graph_execute(
        self,
        *,
        run_id: str,
        graph: "FunctionGraph",
        success: bool,
        error: Exception | None,
        results: dict[str, Any] | None,
    ):
        """Delegate to ``run_after_graph_execution``, passing in the right data.

        :param run_id: ID of the current run.
        :param graph: the function graph that was executed.
        :param success: whether the graph executed successfully.
        :param error: the error that occurred, if any.
        :param results: results of the execution, if any.
        :return: whatever ``run_after_graph_execution`` returns.
        """
        hamilton_graph = HamiltonGraph.from_graph(graph)
        return self.run_after_graph_execution(
            graph=hamilton_graph,
            success=success,
            error=error,
            results=results,
            run_id=run_id,
        )

    @abc.abstractmethod
    def run_before_graph_execution(
        self,
        *,
        graph: graph_types.HamiltonGraph,
        final_vars: list[str],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        execution_path: Collection[str],
        run_id: str,
        **future_kwargs: Any,
    ):
        """Run prior to graph execution.

        This allows you to do anything you want before the graph executes, knowing the basic
        information that was passed in.

        :param graph: Graph that is being executed.
        :param final_vars: Output variables of the graph.
        :param inputs: Input variables passed to the graph.
        :param overrides: Overrides passed to the graph.
        :param execution_path: Collection of nodes that will be executed -- these are just the
            nodes (not input nodes) that will be run during the course of execution.
        :param run_id: Run ID (unique in process scope) of the current run. Use this to track state.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """

    @abc.abstractmethod
    def run_after_graph_execution(
        self,
        *,
        graph: graph_types.HamiltonGraph,
        success: bool,
        error: Exception | None,
        results: dict[str, Any] | None,
        run_id: str,
        **future_kwargs: Any,
    ):
        """Run after graph execution.

        This allows you to do anything you want after the graph executes, knowing the results
        of the execution/any errors.

        :param graph: Graph that is being executed.
        :param success: Whether the graph executed successfully.
        :param error: Error that occurred, None if no error occurred.
        :param results: Results of the graph execution.
        :param run_id: Run ID (unique in process scope) of the current run. Use this to track state.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """
class TaskSubmissionHook(BasePreTaskSubmission, abc.ABC):
    """Implement this to hook into the task submission process.

    Tasks are submitted to an executor, which then controls how and where the nodes
    associated with the task are run.
    """

    @override
    def pre_task_submission(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
    ):
        """Forward every argument unchanged to ``run_before_task_submission``.

        The return value of the user-facing hook is discarded.
        """
        forwarded = dict(
            run_id=run_id,
            task_id=task_id,
            nodes=nodes,
            inputs=inputs,
            overrides=overrides,
            spawning_task_id=spawning_task_id,
            purpose=purpose,
        )
        self.run_before_task_submission(**forwarded)

    @abc.abstractmethod
    def run_before_task_submission(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        """Run prior to a task being submitted to an executor.

        By definition this is run *outside* of the task executor, on the process that
        executed the driver.

        :param run_id: ID of the run this is under.
        :param task_id: ID of the task we're launching.
        :param nodes: Nodes that are part of this task.
        :param inputs: Inputs to the task.
        :param overrides: Overrides passed to the task.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        :param future_kwargs: Reserved for backwards compatibility.
        """
class TaskReturnHook(BasePostTaskReturn, abc.ABC):
    """Implement this to hook into the task return process.

    Tasks are submitted to an executor, which executes the task and returns the results
    (or raises an error).
    """

    @override
    def post_task_return(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        result: Any,
        success: bool,
        error: Exception | None,
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
    ):
        """Forward every argument unchanged to ``run_after_task_return``.

        The return value of the user-facing hook is discarded.
        """
        forwarded = dict(
            run_id=run_id,
            task_id=task_id,
            nodes=nodes,
            result=result,
            success=success,
            error=error,
            spawning_task_id=spawning_task_id,
            purpose=purpose,
        )
        self.run_after_task_return(**forwarded)

    @abc.abstractmethod
    def run_after_task_return(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        result: Any,
        success: bool,
        error: Exception | None,
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        """Run after a task has been returned from an executor.

        By definition this is run *outside* of the task executor, on the process that
        executed the driver.

        :param run_id: ID of the run this is under.
        :param task_id: ID of the task that was just executed.
        :param nodes: Nodes that were part of this task.
        :param result: Result of the task.
        :param success: Whether the task was successful.
        :param error: The error the task threw, if any.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        :param future_kwargs: Reserved for backwards compatibility.
        """
class TaskExecutionHook(BasePreTaskExecute, BasePostTaskExecute, abc.ABC):
    """Implement this to hook into the task execution process.

    Tasks consist of a group of one or more nodes that are run on a task executor. The
    ``pre_task_execute`` / ``post_task_execute`` bridges convert each internal node with
    ``HamiltonNode.from_node`` before calling the user-facing ``run_*`` methods.
    """

    @override
    def pre_task_execute(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
    ):
        """Bridge from the internal hook to ``run_before_task_execution``.

        :param run_id: ID of the run this is under.
        :param task_id: ID of the task about to execute.
        :param nodes: internal nodes of the task; converted to ``HamiltonNode`` objects.
        :param inputs: Inputs to the task.
        :param overrides: Overrides passed to the task.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        """
        self.run_before_task_execution(
            run_id=run_id,
            task_id=task_id,
            nodes=[HamiltonNode.from_node(n) for n in nodes],
            inputs=inputs,
            overrides=overrides,
            spawning_task_id=spawning_task_id,
            purpose=purpose,
        )

    @override
    def post_task_execute(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: list["node.Node"],
        results: dict[str, Any] | None,
        success: bool,
        error: Exception | None,
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
    ):
        """Bridge from the internal hook to ``run_after_task_execution``.

        :param run_id: ID of the run this was under.
        :param task_id: ID of the task that was just executed.
        :param nodes: internal nodes of the task; converted to ``HamiltonNode`` objects.
        :param results: Results of the task, per-node.
        :param success: Whether the task was successful.
        :param error: The error the task threw, if any.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        """
        self.run_after_task_execution(
            run_id=run_id,
            task_id=task_id,
            nodes=[HamiltonNode.from_node(n) for n in nodes],
            results=results,
            success=success,
            error=error,
            spawning_task_id=spawning_task_id,
            purpose=purpose,
        )

    @abc.abstractmethod
    def run_before_task_execution(
        self,
        *,
        task_id: str,
        run_id: str,
        nodes: list[HamiltonNode],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        """Run prior to any of the nodes associated with a task.

        By definition this is run *inside* of the executor and therefore may be run on
        separate or distributed processes.

        :param task_id: ID of the task we're launching.
        :param run_id: ID of the run this is under.
        :param nodes: Nodes that are part of this task.
        :param inputs: Inputs to the task.
        :param overrides: Overrides passed to the task.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        :param future_kwargs: Reserved for backwards compatibility.
        """

    @abc.abstractmethod
    def run_after_task_execution(
        self,
        *,
        task_id: str,
        run_id: str,
        nodes: list[HamiltonNode],
        results: dict[str, Any] | None,
        success: bool,
        error: Exception | None,
        spawning_task_id: str | None,
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        """Run after all of the nodes associated with a task have been executed.

        By definition this is run *inside* of the executor and therefore may be run on
        separate or distributed processes.

        :param task_id: ID of the task that was just executed.
        :param run_id: ID of the run this was under.
        :param nodes: Nodes that were part of this task.
        :param results: Results of the task, per-node.
        :param success: Whether the task was successful.
        :param error: The error the task threw, if any.
        :param spawning_task_id: ID of the task that spawned this task.
        :param purpose: Purpose of the current task group.
        :param future_kwargs: Reserved for backwards compatibility.
        """
class EdgeConnectionHook(BaseDoCheckEdgeTypesMatch, BaseDoValidateInput, abc.ABC):
    """Implement this to customize which edges are allowed in the graph.

    You can do customizations around typing here.
    """

    @override
    @final
    def do_check_edge_types_match(self, *, type_from: type, type_to: type) -> bool:
        """Bridge from the internal hook to ``check_edge_types_match``. Do not override this!

        :param type_from: type at the source of the edge.
        :param type_to: type at the destination of the edge.
        :return: the result of ``check_edge_types_match``.
        """
        verdict = self.check_edge_types_match(type_from, type_to)
        return verdict

    @abc.abstractmethod
    def check_edge_types_match(self, type_from: type, type_to: type, **kwargs: Any) -> bool:
        """Check whether edge types match.

        Note that this is an OR functionality -- this is run after we do some default checks,
        so this can only be permissive. Reach out if you want to be more restrictive than the
        default checks.

        :param type_from: The type of the node that is the source of the edge.
        :param type_to: The type of the node that is the destination of the edge.
        :param kwargs: This is kept for future backwards compatibility.
        :return: Whether or not the two node types form a valid edge.
        """
class NodeExecutionMethod(BaseDoNodeExecute):
    """API for executing a node.

    This takes in tags, callable, node name, and kwargs, and is responsible for executing
    the node and returning the result. Note this is not (currently) able to be layered
    together, although we may add that soon.
    """

    @override
    @final
    def do_node_execute(
        self,
        *,
        run_id: str,
        node_: node.Node,
        kwargs: dict[str, Any],
        task_id: str | None = None,
    ) -> Any:
        """Unpack the node and delegate to ``run_to_execute_node``.

        :param run_id: ID of the current run; accepted but not forwarded.
        :param node_: the node to execute.
        :param kwargs: keyword arguments for the node.
        :param task_id: ID of the enclosing task, if any.
        :return: whatever ``run_to_execute_node`` returns.
        """
        node_types = node.NodeType
        return self.run_to_execute_node(
            node_name=node_.name,
            node_tags=node_.tags,
            node_callable=node_.callable,
            node_kwargs=kwargs,
            task_id=task_id,
            is_expand=(node_.node_role == node_types.EXPAND),
            is_collect=(node_.node_role == node_types.COLLECT),
        )

    @abc.abstractmethod
    def run_to_execute_node(
        self,
        *,
        node_name: str,
        node_tags: dict[str, Any],
        node_callable: Any,
        node_kwargs: dict[str, Any],
        task_id: str | None,
        is_expand: bool,
        is_collect: bool,
        **future_kwargs: Any,
    ) -> Any:
        """Execute the node and return the result.

        :param node_name: Name of the node.
        :param node_tags: Tags of the node.
        :param node_callable: Callable of the node.
        :param node_kwargs: Keyword arguments to pass to the node.
        :param task_id: The ID of the task, None if not in a task-based environment.
        :param is_expand: Whether the node is parallelizable.
        :param is_collect: Whether the node is a collect node.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        :return: The result of the node execution -- up to you to return this.
        """
class StaticValidator(BaseValidateGraph, BaseValidateNode):
    """Performs static validation of the DAG.

    Both ``run_to_validate_*`` methods have a default implementation that reports success,
    so it is OK to implement only one of them.

    .. code-block:: python

        class MyTagValidator(api.StaticValidator):
            '''Validates tags on a node'''

            def run_to_validate_node(
                self, *, node: HamiltonNode, **future_kwargs
            ) -> tuple[bool, str | None]:
                if node.tags.get("node_type", "") == "output":
                    table_name = node.tags.get("table_name")
                    if not table_name:  # None or empty
                        error_msg = (f"Node {node.tags['module']}.{node.name} "
                                     "is an output node, but does not have a table_name tag.")
                        return False, error_msg
                return True, None
    """

    def run_to_validate_node(
        self, *, node: HamiltonNode, **future_kwargs
    ) -> tuple[bool, str | None]:
        """Override this to build custom node validations!

        Defaults to reporting the node as valid, so you don't have to implement it if you
        only want to implement a single method. Runs post node construction to validate a
        node. You have access to a bunch of metadata about the node, stored in the ``node``
        argument.

        :param node: Node to validate.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        :return: A tuple of whether the node is valid and an error message in the case of
            failure. Return (True, None) for a valid node. Otherwise, return a detailed error
            message -- this should have all context/debugging information, but does not need
            to mention the node name (it will be aggregated with others).
        """
        valid, message = True, None
        return valid, message

    def run_to_validate_graph(
        self, graph: HamiltonGraph, **future_kwargs
    ) -> tuple[bool, str | None]:
        """Override this to build custom DAG validations!

        Defaults to reporting the graph as valid, so you don't have to implement it if you
        only want to implement a single method. Runs post graph construction to validate a
        graph. You have access to a bunch of metadata about the graph, stored in the
        ``graph`` argument.

        :param graph: Graph to validate.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        :return: A tuple of whether the graph is valid and an error message in the case of
            failure. Return (True, None) for a valid graph. Otherwise, return a detailed error
            message -- this should have all context/debugging information.
        """
        valid, message = True, None
        return valid, message

    @override
    @final
    def validate_node(self, *, created_node: node.Node) -> tuple[bool, Exception | None]:
        """Convert the internal node to a ``HamiltonNode`` and delegate to ``run_to_validate_node``.

        NOTE(review): the annotation says ``Exception | None`` while the delegate is annotated
        to return ``str | None`` -- confirm which one the caller expects.

        :param created_node: the node that was created.
        :return: whatever ``run_to_validate_node`` returns.
        """
        public_node = HamiltonNode.from_node(created_node)
        return self.run_to_validate_node(node=public_node)

    @override
    @final
    def validate_graph(
        self, *, graph: "FunctionGraph", modules: list[ModuleType], config: dict[str, Any]
    ) -> tuple[bool, Exception | None]:
        """Convert the internal graph to a ``HamiltonGraph`` and delegate to ``run_to_validate_graph``.

        NOTE(review): the annotation says ``Exception | None`` while the delegate is annotated
        to return ``str | None`` -- confirm which one the caller expects.

        :param graph: the function graph that was constructed.
        :param modules: accepted but not forwarded.
        :param config: accepted but not forwarded.
        :return: whatever ``run_to_validate_graph`` returns.
        """
        public_graph = HamiltonGraph.from_graph(graph)
        return self.run_to_validate_graph(graph=public_graph)
class TaskGroupingHook(BasePostTaskGroup, BasePostTaskExpand):
    """Implement this to run something after task grouping or task expansion.

    This will allow you to capture information about the tasks during
    `Parallelize`/`Collect` blocks in dynamic DAG execution.
    """

    @override
    @final
    def post_task_group(self, *, run_id: str, task_ids: list[str]):
        """Forward both arguments to ``run_after_task_grouping`` and return its result."""
        grouping_info = dict(run_id=run_id, task_ids=task_ids)
        return self.run_after_task_grouping(**grouping_info)

    @override
    @final
    def post_task_expand(self, *, run_id: str, task_id: str, parameters: dict[str, Any]):
        """Forward all arguments to ``run_after_task_expansion`` and return its result."""
        expansion_info = dict(run_id=run_id, task_id=task_id, parameters=parameters)
        return self.run_after_task_expansion(**expansion_info)

    @abc.abstractmethod
    def run_after_task_grouping(self, *, run_id: str, task_ids: list[str], **future_kwargs):
        """Run after task grouping.

        This allows you to capture information about which tasks were created for a given run.

        :param run_id: ID of the run, unique in scope of the driver.
        :param task_ids: List of tasks that were grouped together.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """

    @abc.abstractmethod
    def run_after_task_expansion(
        self, *, run_id: str, task_id: str, parameters: dict[str, Any], **future_kwargs
    ):
        """Run after task expansion in Parallelize/Collect blocks.

        This allows you to capture information about the task that was expanded.

        :param run_id: ID of the run, unique in scope of the driver.
        :param task_id: ID of the task that was expanded.
        :param parameters: Parameters that were passed to the task.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """
class GraphConstructionHook(BasePostGraphConstruct, abc.ABC):
    """Hook that is run after graph construction.

    This allows you to register/capture info on the graph. Note that, in the case of
    materialization, this may be called multiple times (once when we create the graph, once
    when we materialize). Currently information about that is not exposed to the user, but
    we will be adding it in future iterations.
    """

    def post_graph_construct(
        self, *, graph: "FunctionGraph", modules: list[ModuleType], config: dict[str, Any]
    ):
        """Convert the graph to a ``HamiltonGraph`` and call ``run_after_graph_construction``.

        :param graph: the function graph that was constructed.
        :param modules: accepted but not forwarded.
        :param config: configuration used to construct the graph; forwarded unchanged.
        """
        public_graph = HamiltonGraph.from_graph(graph)
        self.run_after_graph_construction(graph=public_graph, config=config)

    @abc.abstractmethod
    def run_after_graph_construction(
        self, *, graph: HamiltonGraph, config: dict[str, Any], **future_kwargs: Any
    ):
        """Hook that is run post graph construction.

        This allows you to register/capture info on the graph. A common pattern is to store
        something in your object's state here so that you can use it later (e.g. compute a
        hash on the graph).

        :param graph: Graph that was constructed.
        :param config: Configuration used to construct the graph.
        :param future_kwargs: Reserved for backwards compatibility.
        """