Source code for hamilton.caching.adapter

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import collections
import dataclasses
import enum
import functools
import json
import logging
import pathlib
import uuid
from collections.abc import Callable, Collection
from datetime import datetime, timezone
from typing import Any, Literal, TypeVar

import hamilton.node
from hamilton import graph_types
from hamilton.caching import fingerprinting
from hamilton.caching.cache_key import create_cache_key
from hamilton.caching.stores.base import (
    MetadataStore,
    ResultRetrievalError,
    ResultStore,
    search_data_adapter_registry,
)
from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.function_modifiers.metadata import cache as cache_decorator
from hamilton.graph import FunctionGraph
from hamilton.lifecycle.base import (
    BaseDoNodeExecute,
    BasePostNodeExecute,
    BasePreGraphExecute,
    BasePreNodeExecute,
)

# Logger used by the caching adapter to emit cache events (see `_log_event`).
logger = logging.getLogger("hamilton.caching")

# Unique marker object used by lookups in this module to signal "no value found".
SENTINEL = object()
# Type variable used in return annotations such as `str | S`.
# NOTE(review): presumably stands for the type of `SENTINEL` -- confirm intent.
S = TypeVar("S", object, object)


# String spellings accepted for caching behaviors; they are the lowercase names
# of the `CachingBehavior` enum members.
CACHING_BEHAVIORS = Literal["default", "recompute", "disable", "ignore"]


class CachingBehavior(enum.Enum):
    """Behavior applied by the caching adapter

    DEFAULT:
        Try to retrieve result from cache instead of executing the node. If the node is
        executed, store the result. Compute the result data version and store it too.

    RECOMPUTE:
        Don't try to retrieve result from cache and always execute the node. Otherwise,
        behaves as default. Useful when nodes are stochastic (e.g., model training) or
        interact with external components (e.g., read from database).

    DISABLE:
        Node is executed as if the caching feature wasn't enabled. It never tries to
        retrieve results. Results are never stored nor versioned. Behaves like IGNORE,
        but the node remains a dependency for downstream nodes. This means downstream
        cache lookup will likely fail systematically (i.e., if the cache is empty).

    IGNORE:
        Node is executed as if the caching feature wasn't enabled. It never tries to
        retrieve results. Results are never stored nor versioned. IGNORE means downstream
        nodes will ignore this node as a dependency for lookup. Ignoring clients and
        connections can be useful since they shouldn't directly impact the downstream
        results.
    """

    DEFAULT = 1
    RECOMPUTE = 2
    DISABLE = 3
    IGNORE = 4

    @classmethod
    def from_string(cls, string: str) -> "CachingBehavior":
        """Create a caching behavior from the (case-insensitive) name of an enum member.

        This is leveraged by the cache adapter and the
        ``hamilton.function_modifiers.metadata.cache`` decorator.

        .. code-block::

            CachingBehavior.from_string("recompute")

        :param string: name of the behavior, e.g. ``"default"`` or ``"RECOMPUTE"``.
        :return: the matching ``CachingBehavior`` member.
        :raises KeyError: if ``string`` doesn't name a member, including when it isn't
            a string at all (e.g., ``None``).
        """
        try:
            return cls[string.upper()]
        # AttributeError covers non-string input (no `.upper()`), so callers always
        # get the same, descriptive KeyError for any invalid value.
        except (KeyError, AttributeError) as e:
            raise KeyError(f"{string} is an invalid `CachingBehavior` value") from e
class NodeRoleInTaskExecution(enum.Enum):
    """Identify the role of a node in task-based execution, in particular when
    ``Parallelizable/Collect`` are used.

    NOTE This is an internal construct and it will likely change in the future.

    STANDARD:
        when task-based execution is not used. All nodes and dependencies are STANDARD.

    EXPAND:
        node with type ``Parallelizable``. It returns an iterator where individual items
        need to be handled. Dependencies can only be OUTSIDE.

    COLLECT:
        node with type ``Collect``. It returns an iterable where individual items need to
        be handled. Dependencies can be INSIDE, OUTSIDE, or EXPAND.

    OUTSIDE:
        "outside" of ``Parallelizable/Collect`` paths; handled like STANDARD in most cases.
        Dependencies can be OUTSIDE or COLLECT.

    INSIDE:
        "inside" or "between" ``Parallelizable/Collect`` nodes. Dependencies can be
        INSIDE, OUTSIDE, or EXPAND.
    """

    STANDARD = 1
    EXPAND = 2
    COLLECT = 3
    OUTSIDE = 4
    INSIDE = 5
class CachingEventType(enum.Enum):
    """Event types logged by the caching adapter

    The string value of each member is what appears in the human-readable and
    JSONL renderings of a ``CachingEvent``.
    """

    # reads / writes of data versions and cache keys
    GET_DATA_VERSION = "get_data_version"
    SET_DATA_VERSION = "set_data_version"
    GET_CACHE_KEY = "get_cache_key"
    SET_CACHE_KEY = "set_cache_key"
    # result retrieval / storage
    GET_RESULT = "get_result"
    SET_RESULT = "set_result"
    MISSING_RESULT = "missing_result"
    FAILED_RETRIEVAL = "failed_retrieval"
    # node execution
    EXECUTE_NODE = "execute_node"
    FAILED_EXECUTION = "failed_execution"
    # logged once per node when caching behaviors are resolved
    RESOLVE_BEHAVIOR = "resolve_behavior"
    # logged when a result can't be fingerprinted
    UNHASHABLE_DATA_VERSION = "unhashable_data_version"
    # markers describing how a node participates in a run
    IS_OVERRIDE = "is_override"
    IS_INPUT = "is_input"
    IS_FINAL_VAR = "is_final_var"
    IS_DEFAULT_PARAMETER_VALUE = "is_default_parameter_value"
@dataclasses.dataclass(frozen=True)
class CachingEvent:
    """Event logged by the caching adapter

    Immutable record of a single cache-related action. ``timestamp`` defaults to the
    current UTC time (seconds since the epoch) when the event is created.
    """

    run_id: str
    actor: Literal["adapter", "metadata_store", "result_store"]
    event_type: CachingEventType
    node_name: str
    task_id: str | None = None
    msg: str | None = None
    value: Any | None = None
    timestamp: float = dataclasses.field(
        default_factory=lambda: datetime.now(timezone.utc).timestamp()
    )

    def __str__(self) -> str:
        """Create a human-readable string format for `print()`

        Format: ``node_name[::task_id]::actor::event_type[::msg]``.
        """
        # `str()` guards against a `node_name` of None, which would otherwise make
        # string concatenation raise a TypeError.
        parts = [str(self.node_name)]
        if self.task_id is not None:
            parts.append(f"{self.task_id}")
        parts.append(f"{self.actor}")
        parts.append(f"{self.event_type.value}")
        if self.msg:  # this catches None and empty strings
            parts.append(f"{self.msg}")
        return "::".join(parts)

    def as_dict(self) -> dict[str, Any]:
        """Return the event as a flat dictionary suitable for ``json.dumps``.

        ``event_type`` is replaced by its string value and ``value`` is stringified
        unless it is ``None``.
        """
        return dict(
            run_id=self.run_id,
            timestamp=self.timestamp,
            node_name=self.node_name,
            task_id=self.task_id,
            actor=self.actor,
            event_type=self.event_type.value,
            msg=self.msg,
            # compare against None instead of relying on truthiness: falsy values
            # (0, False, empty containers) are stringified consistently and objects
            # with ambiguous truthiness (e.g., arrays) don't raise.
            value=str(self.value) if self.value is not None else None,
        )
# TODO we could add a "driver-level" kwarg to specify the cache format (e.g., parquet, JSON, etc.)
class HamiltonCacheAdapter(
    BaseDoNodeExecute, BasePreGraphExecute, BasePostNodeExecute, BasePreNodeExecute
):
    """Adapter enabling Hamilton's caching feature through ``Builder.with_cache()``

    .. code-block:: python

        from hamilton import driver
        import my_dataflow

        dr = (
            driver.Builder()
            .with_modules(my_dataflow)
            .with_cache()
            .build()
        )

        # then, you can access the adapter via dr.cache

    """
def __init__(
    self,
    path: str | pathlib.Path = ".hamilton_cache",
    metadata_store: MetadataStore | None = None,
    result_store: ResultStore | None = None,
    default: Literal[True] | Collection[str] | None = None,
    recompute: Literal[True] | Collection[str] | None = None,
    ignore: Literal[True] | Collection[str] | None = None,
    disable: Literal[True] | Collection[str] | None = None,
    default_behavior: CACHING_BEHAVIORS | None = None,
    default_loader_behavior: CACHING_BEHAVIORS | None = None,
    default_saver_behavior: CACHING_BEHAVIORS | None = None,
    log_to_file: bool = False,
    **kwargs,
):
    """Initialize the cache adapter.

    :param path: path where the cache metadata and results will be stored. Used to
        create the default stores when none are passed.
    :param metadata_store: store handling metadata for the cache adapter. Defaults to
        a ``SQLiteMetadataStore`` at ``path``.
    :param result_store: store caching dataflow execution results. Defaults to a
        ``FileResultStore`` at ``path``.
    :param default: Set caching behavior to DEFAULT for specified node names. If True,
        apply to all nodes.
    :param recompute: Set caching behavior to RECOMPUTE for specified node names. If
        True, apply to all nodes.
    :param ignore: Set caching behavior to IGNORE for specified node names. If True,
        apply to all nodes.
    :param disable: Set caching behavior to DISABLE for specified node names. If True,
        apply to all nodes.
    :param default_behavior: Set the default caching behavior.
    :param default_loader_behavior: Set the default caching behavior of `DataLoader` nodes.
    :param default_saver_behavior: Set the default caching behavior of `DataSaver` nodes.
    :param log_to_file: If True, append cache event logs as they happen in JSONL format.
    :param kwargs: accepted but not used by this constructor.
    :raises ValueError: if more than one of ``default``, ``recompute``, ``disable``,
        ``ignore`` is set to ``True``.
    """
    self._path = path

    # fall back to the built-in stores rooted at `path`
    if metadata_store is None:
        metadata_store = SQLiteMetadataStore(path=path)
    self.metadata_store = metadata_store

    if result_store is None:
        result_store = FileResultStore(path=str(path))
    self.result_store = result_store

    self.log_to_file = log_to_file

    # at most one behavior can be applied to "all nodes"
    apply_to_all_flags = (default, recompute, disable, ignore)
    if sum(flag is True for flag in apply_to_all_flags) > 1:
        raise ValueError(
            "Can only set one of (`default`, `recompute`, `disable`, `ignore`) to True. "
            "Please pass mutually exclusive sets of node names"
        )

    self._default = default
    self._recompute = recompute
    self._disable = disable
    self._ignore = ignore
    self.default_behavior = default_behavior
    self.default_loader_behavior = default_loader_behavior
    self.default_saver_behavior = default_saver_behavior

    # attributes populated at execution time; everything is keyed by run_id
    self.run_ids: list[str] = []
    # {run_id: graph}
    self._fn_graphs: dict[str, FunctionGraph] = {}
    # {run_id: list[node_name]}
    self._data_savers: dict[str, Collection[str]] = {}
    # {run_id: list[node_name]}
    self._data_loaders: dict[str, Collection[str]] = {}
    # {run_id: {node_name: behavior}}
    self.behaviors: dict[str, dict[str, CachingBehavior]] = {}
    # {run_id: {node_name: version}} or {run_id: {node_name: {task_id: version}}}
    self.data_versions: dict[str, dict[str, str | dict[str, str]]] = {}
    # {run_id: {node_name: version}}
    self.code_versions: dict[str, dict[str, str]] = {}
    # {run_id: {node_name: key}} or {run_id: {node_name: {task_id: key}}}
    self.cache_keys: dict[str, dict[str, str | dict[str, str]]] = {}
    # {run_id: [logs]}
    self._logs: dict[str, list[CachingEvent]] = {}
@property
def last_run_id(self):
    """Run id of the last started run. Not necessarily the last to complete.

    Raises ``IndexError`` if ``run_ids`` is still empty.
    """
    return self.run_ids[-1]


def __getstate__(self) -> dict:
    """Serialization method required for multiprocessing and multithreading
    when using task-based execution with `Parallelizable/Collect`

    The store instances are replaced by their class and their own ``__getstate__()``
    output so that ``__setstate__`` can re-create them.
    """
    state = self.__dict__.copy()
    # store the classes to reinstantiate the same backend in __setstate__
    state["metadata_store_cls"] = self.metadata_store.__class__
    state["metadata_store_init"] = self.metadata_store.__getstate__()
    state["result_store_cls"] = self.result_store.__class__
    state["result_store_init"] = self.result_store.__getstate__()
    del state["metadata_store"]
    del state["result_store"]
    return state


def __setstate__(self, state: dict) -> None:
    """Serialization method required for multiprocessing and multithreading
    when using task-based execution with `Parallelizable/Collect`.

    Create new instances of metadata and result stores to have one connection
    per thread.
    """
    # work on a copy so the caller's dictionary isn't mutated
    state = dict(state)
    # pop every helper entry added by `__getstate__` so none of them lingers as a
    # stale attribute on the restored instance
    metadata_store_cls = state.pop("metadata_store_cls")
    metadata_store_init = state.pop("metadata_store_init")
    result_store_cls = state.pop("result_store_cls")
    result_store_init = state.pop("result_store_init")

    self.metadata_store = metadata_store_cls(**metadata_store_init)
    self.result_store = result_store_cls(**result_store_init)
    self.__dict__.update(state)


def _log_event(
    self,
    run_id: str,
    node_name: str,
    actor: Literal["adapter", "metadata_store", "result_store"],
    event_type: CachingEventType,
    msg: str | None = None,
    value: Any | None = None,
    task_id: str | None = None,
) -> None:
    """Add a single event to logs stored in state, keyed by run_id

    If global log level is set to logging.INFO, only log if event type is GET_RESULT
    or EXECUTE_NODE; if it is set to logging.DEBUG, log all events.

    If ``log_to_file`` is set to True, append all events to the file
    ``cache_logs.jsonl`` in JSONL format.

    :param run_id: id of the run the event belongs to; must already be a key of
        ``self._logs``
    :param node_name: name of the node associated with the event
    :param task_id: optional identifier when using task-based execution.
        (node_name, task_id) is a primary key
    :param actor: component responsible for the event
    :param event_type: enum specifying what type of event (execute, retrieve, etc.)
    :param msg: additional message to display in the logs (e.g., via terminal)
    :param value: arbitrary value to include (typically a string for data version,
        code version, cache_key). Must be small and JSON-serializable.
    """
    event = CachingEvent(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor=actor,
        event_type=event_type,
        msg=msg,
        value=value,
    )

    # lazy %-formatting: the event is only rendered if a handler emits the record
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("%s", event)
    elif logger.isEnabledFor(logging.INFO):
        if event.event_type in (CachingEventType.GET_RESULT, CachingEventType.EXECUTE_NODE):
            logger.info("%s", event)

    self._logs[run_id].append(event)

    if self.log_to_file:
        # metadata stores aren't required to expose `_directory`; fall back to the
        # adapter's own cache path so custom stores don't break file logging
        log_directory = getattr(self.metadata_store, "_directory", None)
        if log_directory is None:
            log_directory = self._path
        log_file_path = pathlib.Path(log_directory, "cache_logs.jsonl")
        log_file_path.parent.mkdir(parents=True, exist_ok=True)
        json_line = json.dumps(event.as_dict())
        with log_file_path.open("a") as f:
            f.write(json_line + "\n")


def _log_by_node_name(
    self, run_id: str, level: Literal["debug", "info"] = "info"
) -> dict[str | tuple[str, str], list[CachingEvent]]:
    """For a given run, group logs to key them by ``node_name`` or
    ``(node_name, task_id)`` if the event has a task id.

    :param run_id: run to read events from; raises ``KeyError`` if unknown.
    :param level: with ``"info"``, keep only GET_RESULT and EXECUTE_NODE events;
        any other value keeps all events.
    :return: mapping from node key to the list of ``CachingEvent`` in logging order.
    """
    run_logs = collections.defaultdict(list)
    for event in self._logs[run_id]:
        if level == "info" and event.event_type not in (
            CachingEventType.GET_RESULT,
            CachingEventType.EXECUTE_NODE,
        ):
            continue

        key = (event.node_name, event.task_id) if event.task_id else event.node_name
        run_logs[key].append(event)
    return dict(run_logs)
[docs] def logs(self, run_id: str | None = None, level: Literal["debug", "info"] = "info") -> dict: """Execution logs of the cache adapter. :param run_id: If ``None``, return all logged runs. If provided a ``run_id``, group logs by node. :param level: If ``"debug"`` log all events. If ``"info"`` only log if result is retrieved or executed. :return: a mapping between node/task and a list of logged events .. code-block:: python from hamilton import driver import my_dataflow dr = driver.Builder().with_modules(my_dataflow).with_cache().build() dr.execute(...) dr.execute(...) all_logs = dr.cache.logs() # all_logs is a dictionary with run_ids as keys and lists of CachingEvent as values. # { # run_id_1: [CachingEvent(...), CachingEvent(...)], # run_id_2: [CachingEvent(...), CachingEvent(...)], # } run_logs = dr.cache.logs(run_id=dr.last_run_id) # run_logs are keyed by ``node_name`` # {node_name: [CachingEvent(...), CachingEvent(...)], ...} # or ``(node_name, task_id)`` if task-based execution is used. # {(node_name_1, task_id_1): [CachingEvent(...), CachingEvent(...)], ...} """ if run_id: return self._log_by_node_name(run_id=run_id, level=level) logs = collections.defaultdict(list) for run_id, run_logs in self._logs.items(): for event in run_logs: if level == "info" and event.event_type not in ( CachingEventType.GET_RESULT, CachingEventType.EXECUTE_NODE, ): continue logs[run_id].append(event) return dict(logs)
@staticmethod
def _view_run(
    fn_graph: FunctionGraph,
    logs,
    final_vars: list[str],
    inputs: dict,
    overrides: dict,
    output_file_path: str | None = None,
):
    """Create a Hamilton visualization of the execution and the cache hits/misses.

    This leverages the ``custom_style_function`` feature internally: nodes that have
    a GET_RESULT event in ``logs`` are outlined and labelled "from cache".

    :param fn_graph: graph to visualize
    :param logs: mapping of ``{node_name: [CachingEvent, ...]}`` for the run
    :param final_vars: nodes requested for execution
    :param inputs: mapping whose keys are input node names
    :param overrides: mapping whose keys are overridden node names
    :param output_file_path: forwarded to the visualization helper
    :return: whatever ``Driver._visualize_execution_helper`` returns
    """
    from hamilton.driver import Driver  # avoid circular import

    def _style_node(*, node, node_class, logs):
        """Custom style function for the visualization."""
        node_events = logs.get(node.name, [])
        was_retrieved = any(
            event.event_type == CachingEventType.GET_RESULT for event in node_events
        )
        if not was_retrieved:
            return ({}, node_class, None)
        return (
            {"penwidth": "3", "color": "#F06449", "fillcolor": "#ffffff"},
            node_class,
            "from cache",
        )

    style_function = functools.partial(_style_node, logs=logs)
    return Driver._visualize_execution_helper(
        adapter=None,
        bypass_validation=True,
        render_kwargs={},
        output_file_path=output_file_path,
        fn_graph=fn_graph,
        final_vars=final_vars,
        inputs=inputs,
        overrides=overrides,
        custom_style_function=style_function,
    )


# TODO make this work directly from the metadata_store too
# visualization from logs is convenient when debugging someone else's issue
def view_run(self, run_id: str | None = None, output_file_path: str | None = None):
    """View the dataflow execution, including cache hits/misses.

    :param run_id: If ``None``, view the last run. If provided a ``run_id``, view that run.
    :param output_file_path: If provided a path, save the visualization to a file.
    :raises ValueError: if the run's logs are keyed by ``(node_name, task_id)``, i.e.,
        task-based execution was used.

    .. code-block:: python

        from hamilton import driver
        import my_dataflow

        dr = driver.Builder().with_modules(my_dataflow).with_cache().build()

        # execute 3 times
        dr.execute(...)
        dr.execute(...)
        dr.execute(...)

        # view the last run
        dr.cache.view_run()
        # this is equivalent to
        dr.cache.view_run(run_id=dr.cache.last_run_id)

        # get a specific run id
        run_id = dr.cache.run_ids[1]
        dr.cache.view_run(run_id=run_id)

    """
    selected_run_id = self.last_run_id if run_id is None else run_id

    fn_graph = self._fn_graphs[selected_run_id]
    logs = self.logs(selected_run_id, level="debug")

    final_vars: list = []
    inputs: dict = {}
    overrides: dict = {}
    for key, events in logs.items():
        if isinstance(key, tuple):
            raise ValueError(
                "`.view()` is currently not supported for task-based execution. "
                "Please inspect the logs directly via `.logs(run_id=...)` for debugging."
            )

        logged_types = {event.event_type for event in events}
        if CachingEventType.IS_FINAL_VAR in logged_types:
            final_vars.append(key)

        # the value doesn't matter, only the key of the dict
        if CachingEventType.IS_INPUT in logged_types:
            inputs[key] = None
        elif CachingEventType.IS_OVERRIDE in logged_types:
            overrides[key] = None

    return self._view_run(
        fn_graph=fn_graph,
        logs=logs,
        final_vars=final_vars,
        inputs=inputs,
        overrides=overrides,
        output_file_path=output_file_path,
    )
def _get_node_role(
    self, run_id: str, node_name: str, task_id: str | None
) -> NodeRoleInTaskExecution:
    """Determine based on the node name and task_id if a node is part of parallel execution.

    Without a ``task_id`` the node is STANDARD. Otherwise the graph's node type decides
    EXPAND / COLLECT first; the remaining nodes are OUTSIDE when ``node_name`` equals
    ``task_id`` and INSIDE when it doesn't.
    """
    if task_id is None:
        return NodeRoleInTaskExecution.STANDARD

    node_type: hamilton.node.NodeType = self._fn_graphs[run_id].nodes[node_name].node_role
    if node_type == hamilton.node.NodeType.EXPAND:
        return NodeRoleInTaskExecution.EXPAND
    if node_type == hamilton.node.NodeType.COLLECT:
        return NodeRoleInTaskExecution.COLLECT
    if node_name == task_id:
        return NodeRoleInTaskExecution.OUTSIDE
    return NodeRoleInTaskExecution.INSIDE
def get_cache_key(self, run_id: str, node_name: str, task_id: str | None = None) -> str | None:
    """Get the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``,
    and ``task_id``.

    This method is public-facing and can be used directly to inspect the cache.
    Each call logs a GET_CACHE_KEY event with the message ``"hit"`` or ``"miss"``.

    :param run_id: Id of the Hamilton execution run.
    :param node_name: Name of the node associated with the cache key. ``node_name`` is
        a unique identifier if task-based execution is not used.
    :param task_id: Id of the task when task-based execution is used. Then, the tuple
        ``(node_name, task_id)`` is a unique identifier.
    :return: The cache key if it exists, otherwise ``None``.

    .. code-block:: python

        from hamilton import driver
        import my_dataflow

        dr = driver.Builder().with_modules(my_dataflow).with_cache().build()
        dr.execute(...)

        dr.cache.get_cache_key(run_id=dr.cache.last_run_id, node_name="my_node", task_id=None)

    """
    node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id)

    if node_role == NodeRoleInTaskExecution.INSIDE:
        # keys of INSIDE nodes are nested by task: {node_name: {task_id: key}}
        cache_key = self.cache_keys[run_id].get(node_name, {}).get(task_id, SENTINEL)  # type: ignore ; `task_id` can't be None
    else:
        cache_key = self.cache_keys[run_id].get(node_name, SENTINEL)

    # decide hit/miss BEFORE swapping the sentinel for None; checking afterwards
    # would always report "hit"
    is_hit = cache_key is not SENTINEL
    cache_key = cache_key if is_hit else None

    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.GET_CACHE_KEY,
        msg="hit" if is_hit else "miss",
        value=cache_key,
    )

    return cache_key
def _set_cache_key(
    self, run_id: str, node_name: str, cache_key: str, task_id: str | None = None
) -> None:
    """Set the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``,
    and ``task_id``.

    When calling this method, ``cache_key`` must not be ``None``. INSIDE nodes store
    their key under ``{node_name: {task_id: key}}``; every other role stores it
    directly under ``node_name``. A SET_CACHE_KEY event is logged afterwards.
    """
    assert cache_key is not None
    node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id)
    run_cache_keys = self.cache_keys[run_id]

    if node_role == NodeRoleInTaskExecution.INSIDE:
        # create the nested {task_id: key} mapping on first use
        run_cache_keys.setdefault(node_name, {})[task_id] = cache_key  # type: ignore
    elif node_role in (
        NodeRoleInTaskExecution.STANDARD,
        NodeRoleInTaskExecution.OUTSIDE,
        NodeRoleInTaskExecution.EXPAND,
        NodeRoleInTaskExecution.COLLECT,
    ):
        run_cache_keys[node_name] = cache_key
    else:
        raise ValueError(
            f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug."
        )

    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.SET_CACHE_KEY,
        value=cache_key,
    )


def _get_memory_data_version(
    self, run_id: str, node_name: str, task_id: str | None = None
) -> str | S:
    """Get the ``data_version`` stored in-memory for a specific ``run_id``,
    ``node_name``, and ``task_id``.

    The lookup depends on the ``NodeRoleInTaskExecution`` of the node: STANDARD,
    OUTSIDE and COLLECT read the entry of ``node_name``; EXPAND is always a miss;
    INSIDE reads the nested entry of ``task_id``. A GET_DATA_VERSION event with the
    message ``"hit"`` or ``"miss"`` is logged.

    :param run_id: Id of the Hamilton execution run.
    :param node_name: Name of the node associated with the cache key. ``node_name`` is
        a unique identifier if task-based execution is not used.
    :param task_id: Id of the task when task-based execution is used. Then, the tuple
        ``(node_name, task_id)`` is a unique identifier.
    :return: the data version, or ``SENTINEL`` if none is stored.
    """
    node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id)
    run_data_versions = self.data_versions[run_id]

    if node_role == NodeRoleInTaskExecution.EXPAND:
        data_version = SENTINEL
    elif node_role == NodeRoleInTaskExecution.INSIDE:
        per_task_versions = run_data_versions.get(node_name, SENTINEL)
        # anything other than a {task_id: version} mapping counts as a miss
        data_version = (
            per_task_versions.get(task_id, SENTINEL)
            if isinstance(per_task_versions, dict)
            else SENTINEL
        )
    elif node_role in (
        NodeRoleInTaskExecution.STANDARD,
        NodeRoleInTaskExecution.OUTSIDE,
        NodeRoleInTaskExecution.COLLECT,
    ):
        data_version = run_data_versions.get(node_name, SENTINEL)
    else:
        raise ValueError(
            f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug."
        )

    found = data_version is not SENTINEL
    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.GET_DATA_VERSION,
        msg="hit" if found else "miss",
    )

    return data_version


def _get_stored_data_version(
    self, run_id: str, node_name: str, cache_key: str, task_id: str | None = None
) -> str | S:
    """Get the ``data_version`` stored in the metadata store associated with the
    ``cache_key``.

    The ``run_id``, ``node_name``, and ``task_id`` are included only for logging
    purposes. A ``None`` answer from the metadata store is returned as ``SENTINEL``.
    """
    stored_version = self.metadata_store.get(cache_key=cache_key)
    if stored_version is None:
        stored_version = SENTINEL

    found = stored_version is not SENTINEL
    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="metadata_store",
        event_type=CachingEventType.GET_DATA_VERSION,
        msg="hit" if found else "miss",
    )

    return stored_version
def get_data_version(
    self,
    run_id: str,
    node_name: str,
    cache_key: str | None = None,
    task_id: str | None = None,
) -> str | S:
    """Get the ``data_version`` for a specific ``run_id``, ``node_name``, and ``task_id``.

    This method is public-facing and can be used directly to inspect the cache. The
    in-memory data versions are checked first; the metadata store is only queried on
    an in-memory miss and when a ``cache_key`` is given.

    :param run_id: Id of the Hamilton execution run.
    :param node_name: Name of the node associated with the data version. ``node_name``
        is a unique identifier if task-based execution is not used.
    :param cache_key: Key used to query the metadata store. If ``None``, only the
        in-memory data versions are checked.
    :param task_id: Id of the task when task-based execution is used. Then, the tuple
        ``(node_name, task_id)`` is a unique identifier.
    :return: The data version if it exists, otherwise return a sentinel value.

    .. code-block:: python

        from hamilton import driver
        import my_dataflow

        dr = driver.Builder().with_modules(my_dataflow).with_cache().build()
        dr.execute(...)

        dr.cache.get_data_version(run_id=dr.cache.last_run_id, node_name="my_node", task_id=None)

    """
    memory_version = self._get_memory_data_version(
        run_id=run_id, node_name=node_name, task_id=task_id
    )

    # nothing more to try on an in-memory hit or without a key for the metadata store
    if memory_version is not SENTINEL or cache_key is None:
        return memory_version

    return self._get_stored_data_version(
        run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key
    )
def _set_memory_metadata(
    self, run_id: str, node_name: str, data_version: str, task_id: str | None = None
) -> None:
    """Set in-memory data_version whether a task_id is specified or not

    STANDARD, OUTSIDE and COLLECT nodes store ``data_version`` under ``node_name``.
    EXPAND nodes get an empty mapping instead of the value. INSIDE nodes store it
    under ``{node_name: {task_id: data_version}}``. A SET_DATA_VERSION event is
    logged afterwards.
    """
    assert data_version is not None
    node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id)
    run_data_versions = self.data_versions[run_id]

    if node_role == NodeRoleInTaskExecution.EXPAND:
        run_data_versions[node_name] = {}
    elif node_role == NodeRoleInTaskExecution.INSIDE:
        # create the nested {task_id: version} mapping on first use
        run_data_versions.setdefault(node_name, {})[task_id] = data_version  # type: ignore
    elif node_role in (
        NodeRoleInTaskExecution.STANDARD,
        NodeRoleInTaskExecution.OUTSIDE,
        NodeRoleInTaskExecution.COLLECT,
    ):
        run_data_versions[node_name] = data_version
    else:
        raise ValueError(
            f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug."
        )

    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.SET_DATA_VERSION,
        value=data_version,
    )


def _set_stored_metadata(
    self,
    run_id: str,
    node_name: str,
    cache_key: str,
    data_version: str,
    task_id: str | None = None,
) -> None:
    """Set data_version in the metadata store associated with the cache_key

    The node's code version for this run is written alongside it. ``task_id`` is
    only used for the logged SET_DATA_VERSION event.
    """
    code_version = self.code_versions[run_id][node_name]
    self.metadata_store.set(
        run_id=run_id,
        node_name=node_name,
        code_version=code_version,
        data_version=data_version,
        cache_key=cache_key,
    )

    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="metadata_store",
        event_type=CachingEventType.SET_DATA_VERSION,
        value=data_version,
    )


def _version_data(
    self, node_name: str, run_id: str, result: Any, task_id: str | None = None
) -> str:
    """Create a unique data version for the result

    Hashable results return their fingerprint. Unhashable results log an
    UNHASHABLE_DATA_VERSION event, emit a warning, and return a random version.
    """
    fingerprint = fingerprinting.hash_value(result)
    if fingerprint != fingerprinting.UNHASHABLE:
        return fingerprint

    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.UNHASHABLE_DATA_VERSION,
        msg=f"unhashable type {type(result)}; set CachingBehavior.IGNORE to silence warning",
        value=fingerprint,
    )
    logger.warning(
        f"Node `{node_name}` has unhashable result of type `{type(result)}`. "
        "Set `CachingBehavior.IGNORE` or register a versioning function to silence warning. "
        "Learn more: https://hamilton.apache.org/concepts/caching/#caching-behavior\n"
    )
    # if the data version is unhashable, we need to set a random suffix to the cache_key
    # to prevent the cache from thinking this value is constant, causing a cache hit.
    return f"<unhashable>_{uuid.uuid4()}"
def version_data(self, result: Any, run_id: str | None = None) -> str:
    """Create a unique data version for the result

    This is a user-facing method.

    :param result: value to version.
    :param run_id: optional run to associate with the call. When given, the internal
        versioning routine is used and an unhashable result is recorded in that run's
        logs. When omitted, nothing is written to the adapter's logs.
    :return: the data version; unhashable results get a random ``<unhashable>_...``
        version so they never compare equal.
    """
    if run_id is not None:
        return self._version_data(result=result, run_id=run_id, node_name=None)

    # Without a run there is no log to append to. Version the value directly instead
    # of going through the internal routine, which would fail with a KeyError when it
    # tries to log the "unhashable" event under `run_id=None`.
    data_version = fingerprinting.hash_value(result)
    if data_version == fingerprinting.UNHASHABLE:
        logger.warning(
            f"Result of type `{type(result)}` is unhashable. "
            "Set `CachingBehavior.IGNORE` or register a versioning function to silence warning. "
            "Learn more: https://hamilton.apache.org/concepts/caching/#caching-behavior\n"
        )
        # random suffix so an unhashable value is never mistaken for a constant
        data_version = f"<unhashable>_{uuid.uuid4()}"
    return data_version
def version_code(self, node_name: str, run_id: str | None = None) -> str:
    """Create a unique code version for the source code defining the node

    :param node_name: name of the node to version.
    :param run_id: run whose graph holds the node; defaults to the last started run.
    :return: the ``version`` of the ``HamiltonNode`` built from the graph node.
    """
    if run_id is None:
        run_id = self.last_run_id

    graph_node = self._fn_graphs[run_id].nodes[node_name]
    hamilton_node = graph_types.HamiltonNode.from_node(graph_node)
    return hamilton_node.version  # type: ignore
def _execute_node(
    self,
    run_id: str,
    node_name: str,
    node_callable: Callable,
    node_kwargs: dict[str, Any],
    task_id: str | None = None,
) -> Any:
    """Simple wrapper that logs the regular execution of a node.

    The EXECUTE_NODE event is logged only after ``node_callable`` returned.
    """
    logger.debug(node_name)
    output = node_callable(**node_kwargs)
    self._log_event(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        actor="adapter",
        event_type=CachingEventType.EXECUTE_NODE,
    )
    return output


@staticmethod
def _resolve_node_behavior(
    node: hamilton.node.Node,
    default: Collection[str] | None = None,
    disable: Collection[str] | None = None,
    recompute: Collection[str] | None = None,
    ignore: Collection[str] | None = None,
    default_behavior: CACHING_BEHAVIORS = "default",
    default_loader_behavior: CACHING_BEHAVIORS = "default",
    default_saver_behavior: CACHING_BEHAVIORS = "default",
) -> CachingBehavior:
    """Determine the cache behavior of a node.

    Precedence, from highest to lowest:

    1. ``Parallelizable`` (EXPAND) nodes are always ``RECOMPUTE`` to ensure yielded
       items are versioned individually.
    2. Behavior specified via the ``Builder`` (the ``default``, ``disable``,
       ``recompute``, ``ignore`` collections).
    3. Behavior specified via the ``@cache`` decorator tag.
    4. ``default_loader_behavior`` / ``default_saver_behavior`` for nodes tagged as
       data loader / data saver.
    5. ``default_behavior``.

    :raises ValueError: if the node name is in more than one of the Builder collections.
    """
    if node.node_role == hamilton.node.NodeType.EXPAND:
        return CachingBehavior.RECOMPUTE

    # the tag is parsed up front, so an invalid tag value fails even when a
    # Builder-level behavior would take precedence
    tag_behavior = node.tags.get(cache_decorator.BEHAVIOR_KEY, SENTINEL)
    if tag_behavior is not SENTINEL:
        tag_behavior = CachingBehavior.from_string(tag_behavior)

    driver_behavior = SENTINEL
    builder_selections = (
        (CachingBehavior.DEFAULT, default),
        (CachingBehavior.DISABLE, disable),
        (CachingBehavior.RECOMPUTE, recompute),
        (CachingBehavior.IGNORE, ignore),
    )
    for candidate, selected_names in builder_selections:
        # guard against default None value
        if selected_names is None or node.name not in selected_names:
            continue
        if driver_behavior is not SENTINEL:
            raise ValueError(
                f"Multiple caching behaviors specified by Driver for node: {node.name}"
            )
        driver_behavior = candidate

    if driver_behavior is not SENTINEL:
        return driver_behavior
    if tag_behavior is not SENTINEL:
        return tag_behavior
    if node.tags.get("hamilton.data_loader"):
        return CachingBehavior.from_string(default_loader_behavior)
    if node.tags.get("hamilton.data_saver"):
        return CachingBehavior.from_string(default_saver_behavior)
    return CachingBehavior.from_string(default_behavior)
def resolve_behaviors(self, run_id: str) -> dict[str, CachingBehavior]:
    """Resolve the caching behavior for each node based on the ``@cache`` decorator
    and the ``Builder.with_cache()`` parameters for a specific ``run_id``.

    This is a user-facing method.

    Behavior specified via ``Builder.with_cache()`` have precedence. If no parameters
    are specified, the ``CachingBehavior.DEFAULT`` is used. If a node is
    ``Parallelizable`` (i.e., ``@expand``), the ``CachingBehavior`` is set to
    ``CachingBehavior.RECOMPUTE`` to ensure the yielded items are versioned individually.
    Internally, this uses the ``FunctionGraph`` stored for each ``run_id`` and logs the
    resolved caching behavior for each node.

    Besides returning the behaviors, this appends to ``self._data_loaders[run_id]`` and
    ``self._data_savers[run_id]``; both lists must exist before calling.

    :param run_id: Id of the Hamilton execution run.
    :return: A dictionary of ``{node name: caching behavior}``.
    """
    graph = self._fn_graphs[run_id]

    _default = self._default
    _disable = self._disable
    _recompute = self._recompute
    _ignore = self._ignore

    # a value of `True` means "apply to all nodes": expand it to every node name.
    # `__init__` rejects more than one `True`, hence the if/elif chain.
    if _default is True:
        _default = [n.name for n in graph.get_nodes()]
    elif _disable is True:
        _disable = [n.name for n in graph.get_nodes()]
    elif _recompute is True:
        _recompute = [n.name for n in graph.get_nodes()]
    elif _ignore is True:
        _ignore = [n.name for n in graph.get_nodes()]

    default_behavior = "default"
    if self.default_behavior is not None:
        default_behavior = self.default_behavior

    # loader / saver defaults fall back to the general default behavior
    default_loader_behavior = default_behavior
    if self.default_loader_behavior is not None:
        default_loader_behavior = self.default_loader_behavior

    default_saver_behavior = default_behavior
    if self.default_saver_behavior is not None:
        default_saver_behavior = self.default_saver_behavior

    # first pass: resolve and log the behavior of every node
    behaviors = {}
    for node in graph.get_nodes():
        behavior = HamiltonCacheAdapter._resolve_node_behavior(
            node=node,
            default=_default,
            disable=_disable,
            recompute=_recompute,
            ignore=_ignore,
            default_behavior=default_behavior,
            default_loader_behavior=default_loader_behavior,
            default_saver_behavior=default_saver_behavior,
        )
        behaviors[node.name] = behavior

        self._log_event(
            run_id=run_id,
            node_name=node.name,
            task_id=None,
            actor="adapter",
            event_type=CachingEventType.RESOLVE_BEHAVIOR,
            value=behavior,
        )

    # need to handle materializers via a second pass to copy the behavior
    # of their "main node"
    for node in graph.get_nodes():
        if node.tags.get("hamilton.data_loader") is True:
            main_node = node.tags["hamilton.data_loader.node"]
            if main_node == node.name:
                continue

            # solution for `@dataloader` and `from_`
            if behaviors.get(main_node, None) is not None:
                behaviors[node.name] = behaviors[main_node]
            # this hacky section is required to support @load_from and provide
            # a unified pattern to specify behavior from the module or the driver
            else:
                behaviors[node.name] = HamiltonCacheAdapter._resolve_node_behavior(
                    # we create a fake node, only its name matters
                    node=hamilton.node.Node(
                        name=main_node,
                        typ=str,
                        callabl=lambda: None,
                        tags=node.tags.copy(),
                    ),
                    default=_default,
                    disable=_disable,
                    recompute=_recompute,
                    ignore=_ignore,
                    default_behavior=default_loader_behavior,
                )

            # NOTE(review): nesting reconstructed from flattened source; assumed to run
            # for both branches above -- confirm against the upstream file.
            self._data_loaders[run_id].append(main_node)

        if node.tags.get("hamilton.data_saver", None) is not None:
            self._data_savers[run_id].append(node.name)

    return behaviors
[docs] def resolve_code_versions( self, run_id: str, final_vars: list[str] | None = None, inputs: dict[str, Any] | None = None, overrides: dict[str, Any] | None = None, ) -> dict[str, str]: """Resolve the code version for each node for a specific ``run_id``. This is a user-facing method. If ``final_vars`` is None, all nodes will be versioned. If ``final_vars`` is provided, the ``inputs`` and ``overrides`` are used to determine the execution path and only version the code for these nodes. :param run_id: Id of the Hamilton execution run. :param final_vars: Nodes requested for execution. :param inputs: Input node values. :param overrides: Override node values. :return: A dictionary of ``{node name: code version}``. """ graph = self._fn_graphs[run_id] final_vars = [] if final_vars is None else final_vars inputs = {} if inputs is None else inputs overrides = {} if overrides is None else overrides node_selection = graph.get_nodes() if len(final_vars) > 0: all_nodes, user_defined_nodes = graph.get_upstream_nodes(final_vars, inputs, overrides) node_selection = set(all_nodes) - set(user_defined_nodes) return { node.name: self.version_code(run_id=run_id, node_name=node.name) for node in node_selection }
def _process_input(self, run_id: str, node_name: str, value: Any) -> None: """Process input nodes to version data and code. To enable caching, input values must be versioned. Since inputs have no associated code, set a constant "code version" ``f"input__{node_name}"`` that uniquely identifies this input. """ data_version = self._version_data(node_name=node_name, run_id=run_id, result=value) self.code_versions[run_id][node_name] = f"input__{node_name}" self.data_versions[run_id][node_name] = data_version self._log_event( run_id=run_id, node_name=node_name, task_id=None, actor="adapter", event_type=CachingEventType.IS_INPUT, value=data_version, ) def _process_override(self, run_id: str, node_name: str, value: Any) -> None: """Process override nodes to version data and code. To enable caching, override values must be versioned. As opposed to executed nodes, code and data versions for overrides are not stored because their value is user provided and isn't necessarily tied to the code. """ data_version = self._version_data(node_name=node_name, run_id=run_id, result=value) self.data_versions[run_id][node_name] = data_version self._log_event( run_id=run_id, node_name=node_name, task_id=None, actor="adapter", event_type=CachingEventType.IS_OVERRIDE, value=data_version, ) @staticmethod def _resolve_default_parameter_values( node_: hamilton.node.Node, node_kwargs: dict[str, Any] ) -> dict[str, Any]: """ If a node uses the function's default parameter values, they won't be part of the node_kwargs. To ensure a consistent `cache_key` we want to retrieve default parameter values if they're used """ resolved_kwargs = node_kwargs.copy() for param_name, param_value in node_.default_parameter_values.items(): # if the `param_name` not in `node_kwargs`, it means the node uses the default # parameter value if param_name not in node_kwargs.keys(): resolved_kwargs.update(**{param_name: param_value}) return resolved_kwargs
def pre_graph_execute(
    self,
    *,
    run_id: str,
    graph: FunctionGraph,
    final_vars: list[str],
    inputs: dict[str, Any],
    overrides: dict[str, Any],
):
    """Prepare the adapter state for a new execution.

    The same adapter instance is shared across all ``Driver.execute()`` calls,
    so most attributes are keyed by ``run_id`` to prevent conflicts between runs.

    :param run_id: Id of the Hamilton execution run.
    :param graph: Function graph of the run; kept for later lookups.
    :param final_vars: Nodes requested for execution.
    :param inputs: Input node values; each one is versioned.
    :param overrides: Override node values; each one is versioned.
    """
    self.run_ids.append(run_id)
    self.metadata_store.initialize(run_id)

    # per-run containers
    self._logs[run_id] = []
    self._fn_graphs[run_id] = graph
    self.data_versions[run_id] = {}
    self.cache_keys[run_id] = {}

    resolved_code_versions = self.resolve_code_versions(
        run_id=run_id, final_vars=final_vars, inputs=inputs, overrides=overrides
    )
    self.code_versions[run_id] = resolved_code_versions

    # `self.resolve_behaviors` appends to `._data_loaders` and `._data_savers`,
    # so both must exist (empty) before it is called
    self._data_loaders[run_id] = []
    self._data_savers[run_id] = []
    resolved_behaviors = self.resolve_behaviors(run_id=run_id)
    self.behaviors[run_id] = resolved_behaviors

    # final vars are logged to be retrieved by the ``.view_run()`` method
    for requested_name in final_vars:
        self._log_event(
            run_id=run_id,
            node_name=requested_name,
            task_id=None,
            actor="adapter",
            event_type=CachingEventType.IS_FINAL_VAR,
        )

    for input_name, input_value in (inputs or {}).items():
        self._process_input(run_id, input_name, input_value)

    for override_name, override_value in (overrides or {}).items():
        self._process_override(run_id, override_name, override_value)
def pre_node_execute(
    self,
    *,
    run_id: str,
    node_: hamilton.node.Node,
    kwargs: dict[str, Any],
    task_id: str | None = None,
    **future_kwargs,
):
    """Create the node's cache_key before execution or retrieval and keep it in memory.

    The cache_key combines the node's code version with the data versions of its
    dependencies. Nodes with the ``IGNORE`` or ``DISABLE`` behavior get no cache_key.

    Gathering the ``data_version`` of upstream dependencies has special cases under
    task-based execution:

    - When the current node is ``COLLECT``, the dependency annotated with ``Collect[]``
      is versioned item by item rather than as a container, because the collect
      order is inconsistent.
    - When the dependency is ``EXPAND``, ``kwargs`` holds a single item yielded by
      the iterator. That item is versioned directly because, from "inside" the
      branch, some executors (multiprocessing, multithreading) lose access to the
      data versions of ``OUTSIDE`` nodes stored in ``self.data_versions``.

    :param run_id: Id of the Hamilton execution run.
    :param node_: Node about to be executed or retrieved.
    :param kwargs: Values of the node's dependencies, keyed by dependency name.
    :param task_id: Id of the task when task-based execution is used, else ``None``.
    """
    node_name = node_.name
    node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)
    run_behaviors = self.behaviors[run_id]

    # the cache_key isn't needed for either result retrieval or storage
    if run_behaviors[node_name] in (CachingBehavior.IGNORE, CachingBehavior.DISABLE):
        return

    node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id)
    if node_role == NodeRoleInTaskExecution.COLLECT:
        collected_name = node_.collect_dependency
    else:
        collected_name = SENTINEL

    versions_by_dependency = {}
    for dep_name, dep_value in node_kwargs.items():
        dep_behavior = run_behaviors[dep_name]
        if dep_behavior == CachingBehavior.IGNORE:
            # a constant marker makes the dependency's value appear constant to this node
            versions_by_dependency[dep_name] = "<ignore>"
            continue
        if dep_behavior == CachingBehavior.DISABLE:
            # a random suffix forces a cache miss, hence a recompute, of the current node
            versions_by_dependency[dep_name] = f"<disable>_{uuid.uuid4()}"
            continue

        if task_id is None:
            dep_role = NodeRoleInTaskExecution.STANDARD
        else:
            # checking whether the dependency is an EXPAND node; the current
            # `task_id` must not be passed
            dep_role = self._get_node_role(
                run_id=run_id, node_name=dep_name, task_id="<placeholder>"
            )

        if dep_name == collected_name:
            # hash the collected value based on its items, not the container
            sorted_item_versions = sorted(
                self.version_data(item, run_id=run_id) for item in dep_value
            )
            dep_version = fingerprinting.hash_value(sorted_item_versions)
        elif dep_role == NodeRoleInTaskExecution.EXPAND:
            # the kwarg is a single item yielded by the iterator rather than the full
            # iterable; version it directly, similar to a top-level input
            dep_version = self.version_data(dep_value, run_id=run_id)
        else:
            known_version = self._get_memory_data_version(
                run_id=run_id, node_name=dep_name, task_id=None
            )
            if known_version is SENTINEL:
                dep_version = self.version_data(dep_value, run_id=run_id)
            elif isinstance(known_version, dict):
                dep_version = known_version.get(task_id)
            else:
                dep_version = known_version

        if dep_version == fingerprinting.UNHASHABLE:
            # a random suffix prevents the cache from treating an unhashable value
            # as constant, which would cause a cache hit
            dep_version = f"<unhashable>_{uuid.uuid4()}"

        versions_by_dependency[dep_name] = dep_version

    # created before execution; reused during and after execution
    node_cache_key = create_cache_key(
        node_name=node_name,
        code_version=self.code_versions[run_id][node_name],
        dependencies_data_versions=versions_by_dependency,
    )
    self._set_cache_key(
        run_id=run_id, node_name=node_name, task_id=task_id, cache_key=node_cache_key
    )
def _version_executed_result(
    self,
    *,
    run_id: str,
    node_name: str,
    task_id: str | None,
    cache_key: str,
    result: Any,
) -> Any:
    """Version a freshly computed result and record its metadata in memory and in the store.

    :return: The result to hand back to the caller. For nodes listed in
        ``._data_loaders`` that returned a tuple, this is the first element only.
    """
    # nodes collected in `._data_loaders` return tuples of (result, metadata)
    # where metadata often includes a timestamp. To ensure we provide a consistent
    # `data_version` / hash, we only hash the result part of the materializer return
    # value and discard the metadata.
    if node_name in self._data_loaders[run_id] and isinstance(result, tuple):
        result = result[0]

    data_version = self._version_data(node_name=node_name, run_id=run_id, result=result)

    # nodes collected in `._data_savers` return a dictionary of metadata
    # this metadata often includes a timestamp, leading to an unstable hash.
    # we do not version nor store the metadata. This node is executed for its
    # external effect of saving a file
    if node_name in self._data_savers[run_id]:
        data_version = f"{node_name}__metadata"

    self._set_memory_metadata(
        run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
    )
    self._set_stored_metadata(
        run_id=run_id,
        node_name=node_name,
        task_id=task_id,
        cache_key=cache_key,
        data_version=data_version,
    )
    return result


def do_node_execute(
    self,
    *,
    run_id: str,
    node_: hamilton.node.Node,
    kwargs: dict[str, Any],
    task_id: str | None = None,
    **future_kwargs,
):
    """Try to retrieve a stored result from previous executions, or execute the node.

    What happens depends on the node's caching behavior:

    - ``DISABLE`` / ``IGNORE``: the node is executed as if caching wasn't enabled.
      ``pre_node_execute`` creates no cache_key for these nodes, so the result is
      neither versioned nor recorded.
    - ``RECOMPUTE``: the node is always executed; its result is versioned and the
      metadata is recorded.
    - ``DEFAULT``: the cache_key created in ``pre_node_execute`` is used to get the
      data_version from memory or the metadata_store. If a data_version is found,
      try to retrieve the result. If that fails, or nothing was found, execute the
      node, version its result, and record the metadata.

    :param run_id: Id of the Hamilton execution run.
    :param node_: Node to execute or retrieve.
    :param kwargs: Values of the node's dependencies, keyed by dependency name.
    :param task_id: Id of the task when task-based execution is used, else ``None``.
    :return: The retrieved or computed result of the node.
    """
    node_name = node_.name
    node_callable = node_.callable
    node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)
    behavior = self.behaviors[run_id][node_name]

    def execute_node() -> Any:
        return self._execute_node(
            run_id=run_id,
            node_name=node_name,
            node_callable=node_callable,
            node_kwargs=node_kwargs,
            task_id=task_id,
        )

    if behavior in (CachingBehavior.DISABLE, CachingBehavior.IGNORE):
        # results are never stored nor versioned for these behaviors; there is
        # no cache_key to record them under since `pre_node_execute` returned early
        return execute_node()

    if behavior == CachingBehavior.RECOMPUTE:
        # never attempt retrieval; otherwise behaves as DEFAULT after execution
        result = execute_node()
        cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id)
        return self._version_executed_result(
            run_id=run_id,
            node_name=node_name,
            task_id=task_id,
            cache_key=cache_key,
            result=result,
        )

    # cache_key is set in `pre_node_execute`
    cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id)
    # retrieve data version from memory or metadata_store
    data_version = self.get_data_version(
        run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key
    )

    need_to_compute_node = False
    if data_version is SENTINEL:
        # must execute: data_version not found in memory or in metadata_store
        need_to_compute_node = True
    elif data_version == fingerprinting.UNHASHABLE:
        # must execute: the retrieved data_version is UNHASHABLE, therefore it isn't stored.
        need_to_compute_node = True
    elif not self.result_store.exists(data_version):
        # must execute: data_version retrieved, but result store can't find result
        need_to_compute_node = True
        self._log_event(
            run_id=run_id,
            node_name=node_name,
            task_id=task_id,
            actor="result_store",
            event_type=CachingEventType.MISSING_RESULT,
            value=data_version,
        )
    else:
        # try to retrieve: data_version retrieved, result store found result
        try:
            # successful retrieval: retrieve the result; potentially load using
            # the DataLoader if e.g., ``@cache(format="json")``
            result = self.result_store.get(data_version=data_version)
            self._log_event(
                run_id=run_id,
                node_name=node_name,
                task_id=task_id,
                actor="result_store",
                event_type=CachingEventType.GET_RESULT,
                msg="hit",
                value=data_version,
            )
            # set the data_version previously retrieved (could be from memory or store)
            self._set_memory_metadata(
                run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
            )
        except ResultRetrievalError:
            # failed retrieval: despite finding the result, probably failed loading
            # data using DataLoader if e.g., ``@cache(format="json")``.
            # Drop the stale entries so the recomputed result replaces them.
            self.metadata_store.delete(cache_key=cache_key)
            self.result_store.delete(data_version)
            need_to_compute_node = True

    if need_to_compute_node:
        result = self._version_executed_result(
            run_id=run_id,
            node_name=node_name,
            task_id=task_id,
            cache_key=cache_key,
            result=execute_node(),
        )

    return result


def post_node_execute(
    self,
    *,
    run_id: str,
    node_: hamilton.node.Node,
    result: Any,
    success: bool = True,
    error: Exception | None = None,
    task_id: str | None = None,
    **future_kwargs,
):
    """Store the node result in the result_store if it isn't there yet.

    The cache_key and data_version are read back from memory (set respectively in
    ``pre_node_execute`` and ``do_node_execute``). Only ``DEFAULT`` and ``RECOMPUTE``
    nodes are stored; ``DISABLE`` and ``IGNORE`` nodes have no cache_key nor
    data_version and their results are never stored.

    :param run_id: Id of the Hamilton execution run.
    :param node_: Node that was executed or retrieved.
    :param result: Result of the node.
    :param success: Whether the node execution succeeded. On failure, only a
        ``FAILED_EXECUTION`` event is logged.
    :param error: Exception raised by the node, if any; used in the failure log message.
    :param task_id: Id of the task when task-based execution is used, else ``None``.
    """
    node_name = node_.name

    if success is False:
        self._log_event(
            run_id=run_id,
            node_name=node_name,
            task_id=task_id,
            actor="adapter",
            event_type=CachingEventType.FAILED_EXECUTION,
            msg=f"{error}",
        )
        return

    if self.behaviors[run_id][node_name] not in (
        CachingBehavior.DEFAULT,
        CachingBehavior.RECOMPUTE,
    ):
        return

    cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id)
    data_version = self.get_data_version(
        run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key
    )
    # internal invariant: `do_node_execute` always sets a data_version for these behaviors
    assert data_version is not SENTINEL

    # TODO clean up this logic
    # check if a materialized file exists before writing results
    # when using `@cache(format="json")`
    cache_format = (
        self._fn_graphs[run_id].nodes[node_name].tags.get(cache_decorator.FORMAT_KEY, SENTINEL)
    )
    if cache_format is not SENTINEL:
        saver_cls, loader_cls = search_data_adapter_registry(
            name=cache_format, type_=type(result)
        )  # type: ignore
        materialized_path = self.result_store._materialized_path(data_version, saver_cls)
        materialized_path_missing = not materialized_path.exists()
    else:
        saver_cls, loader_cls = None, None
        materialized_path_missing = False

    result_missing = not self.result_store.exists(data_version)
    if result_missing or materialized_path_missing:
        self.result_store.set(
            data_version=data_version,
            result=result,
            saver_cls=saver_cls,
            loader_cls=loader_cls,
        )
        self._log_event(
            run_id=run_id,
            node_name=node_name,
            task_id=task_id,
            actor="result_store",
            event_type=CachingEventType.SET_RESULT,
            value=data_version,
        )