Source code for hamilton.caching.stores.memory

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from collections.abc import Sequence
from typing import Any

try:
    from typing import override
except ImportError:
    override = lambda x: x  # noqa E731

from hamilton.caching.cache_key import decode_key

from .base import MetadataStore, ResultStore, StoredResult
from .file import FileResultStore
from .sqlite import SQLiteMetadataStore


class InMemoryMetadataStore(MetadataStore):
    """Metadata store that keeps all cache metadata in Python containers.

    Three structures are maintained:

    * ``_data_versions``: ``{cache_key: data_version}``
    * ``_cache_keys_by_run``: ``{run_id: [cache_key, ...]}`` in insertion order
    * ``_run_ids``: run ids in the order they were initialized

    Nothing is written to disk unless :meth:`persist_to` is called.
    """

    def __init__(self) -> None:
        self._data_versions: dict[str, str] = {}  # {cache_key: data_version}
        self._cache_keys_by_run: dict[str, list[str]] = {}  # {run_id: [cache_key]}
        self._run_ids: list[str] = []

    @override
    def __len__(self) -> int:
        """Number of unique ``cache_key`` values."""
        return len(self._data_versions)

    @override
    def exists(self, cache_key: str) -> bool:
        """Indicate if ``cache_key`` exists and it can retrieve a ``data_version``."""
        return cache_key in self._data_versions

    @override
    def initialize(self, run_id: str) -> None:
        """Set up and log the beginning of the run.

        Registers ``run_id`` with an empty list of cache keys. This must be
        called before :meth:`set` is used with that ``run_id``.
        """
        self._cache_keys_by_run[run_id] = []
        self._run_ids.append(run_id)

    @override
    def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Any | None:
        """Set the ``data_version`` for ``cache_key`` and associate it with the ``run_id``.

        :param cache_key: key identifying the cached node execution.
        :param data_version: version of the result stored under ``cache_key``.
        :param run_id: run to which this entry is attributed.
        :param kwargs: accepted for interface compatibility; ignored here.
        :raises KeyError: if ``run_id`` was never passed to :meth:`initialize`.
        """
        self._data_versions[cache_key] = data_version
        self._cache_keys_by_run[run_id].append(cache_key)

    @override
    def get(self, cache_key: str) -> str | None:
        """Retrieve the ``data_version`` for ``cache_key``; ``None`` if absent."""
        return self._data_versions.get(cache_key, None)

    @override
    def delete(self, cache_key: str) -> None:
        """Delete the ``data_version`` for ``cache_key``.

        Only the ``cache_key -> data_version`` entry is removed; the key stays
        in the per-run history and is skipped by :meth:`get_run` and
        :meth:`persist_to` for as long as it has no data version.

        :raises KeyError: if ``cache_key`` is not stored.
        """
        del self._data_versions[cache_key]

    @override
    def delete_all(self) -> None:
        """Delete all stored metadata.

        Clears every ``cache_key -> data_version`` entry. Run ids remain
        registered, so :meth:`set` keeps working for runs that were already
        initialized.
        """
        self._data_versions.clear()

    def persist_to(self, metadata_store: MetadataStore | None = None) -> None:
        """Persist in-memory metadata using another MetadataStore implementation.

        Cache keys whose data version has been deleted are not persisted.

        :param metadata_store: MetadataStore implementation to use for persistence.
            If None, a SQLiteMetadataStore is created with the default path "./.hamilton_cache".

        .. code-block:: python

            from hamilton import driver
            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
            from hamilton.caching.stores.memory import InMemoryMetadataStore
            import my_dataflow

            dr = (
                driver.Builder()
                .with_modules(my_dataflow)
                .with_cache(metadata_store=InMemoryMetadataStore())
                .build()
            )

            # execute the Driver several time. This will populate the in-memory metadata store
            dr.execute(...)

            # persist to disk in-memory metadata
            dr.cache.metadata_store.persist_to(SQLiteMetadataStore(path="./.hamilton_cache"))
        """
        if metadata_store is None:
            metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")

        # Register every run first so the target store knows all run ids
        # before any entry is attributed to them.
        for run_id in self._run_ids:
            metadata_store.initialize(run_id)

        for run_id, cache_keys in self._cache_keys_by_run.items():
            for cache_key in cache_keys:
                if cache_key not in self._data_versions:
                    # The entry was removed via ``delete``/``delete_all`` after
                    # being recorded for this run; nothing left to persist.
                    continue
                metadata_store.set(
                    cache_key=cache_key,
                    data_version=self._data_versions[cache_key],
                    run_id=run_id,
                )

    @classmethod
    def load_from(cls, metadata_store: MetadataStore) -> "InMemoryMetadataStore":
        """Load in-memory metadata from another MetadataStore instance.

        :param metadata_store: MetadataStore instance to load from.
        :return: InMemoryMetadataStore copy of the ``metadata_store``.

        .. code-block:: python

            from hamilton import driver
            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
            from hamilton.caching.stores.memory import InMemoryMetadataStore
            import my_dataflow

            sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
            in_memory_metadata_store = InMemoryMetadataStore.load_from(sqlite_metadata_store)

            # create the Driver with the in-memory metadata store
            dr = (
                driver.Builder()
                .with_modules(my_dataflow)
                .with_cache(metadata_store=in_memory_metadata_store)
                .build()
            )
        """
        # ``cls()`` rather than the hard-coded class name so that subclasses
        # get an instance of their own type.
        in_memory_metadata_store = cls()

        for run_id in metadata_store.get_run_ids():
            in_memory_metadata_store.initialize(run_id)

            for node_metadata in metadata_store.get_run(run_id):
                in_memory_metadata_store.set(
                    cache_key=node_metadata["cache_key"],
                    data_version=node_metadata["data_version"],
                    run_id=run_id,
                )

        return in_memory_metadata_store

    @override
    def get_run_ids(self) -> list[str]:
        """Return a list of all ``run_id`` values stored.

        A copy is returned so callers cannot mutate the store's internal state.
        """
        return list(self._run_ids)

    @override
    def get_run(self, run_id: str) -> list[dict[str, str]]:
        """Return a list of node metadata associated with a run.

        Each item holds ``cache_key``, ``data_version``, and the ``node_name``,
        ``code_version`` and ``dependencies_data_versions`` decoded from the
        cache key. Cache keys whose data version has been deleted are skipped.

        :param run_id: run for which to collect node metadata.
        :raises IndexError: if ``run_id`` is unknown to this store.
        """
        if self._cache_keys_by_run.get(run_id, None) is None:
            raise IndexError(f"Run ID not found: {run_id}")

        nodes_metadata = []
        for cache_key in self._cache_keys_by_run[run_id]:
            if cache_key not in self._data_versions:
                # Deleted after being recorded for this run.
                continue

            decoded_key = decode_key(cache_key)
            nodes_metadata.append(
                dict(
                    cache_key=cache_key,
                    data_version=self._data_versions[cache_key],
                    node_name=decoded_key["node_name"],
                    code_version=decoded_key["code_version"],
                    dependencies_data_versions=decoded_key["dependencies_data_versions"],
                )
            )

        return nodes_metadata
class InMemoryResultStore(ResultStore):
    """Result store that keeps results in a dictionary keyed by ``data_version``."""

    def __init__(self, persist_on_exit: bool = False) -> None:
        # NOTE: ``persist_on_exit`` is accepted but not used anywhere in this class.
        self._results: dict[str, StoredResult] = {}  # {data_version: result}

    @override
    def exists(self, data_version: str) -> bool:
        """Indicate if a result is stored for ``data_version``."""
        return data_version in self._results

    # TODO handle materialization
    @override
    def set(self, data_version: str, result: Any, **kwargs) -> None:
        """Store ``result`` under ``data_version``, wrapped in a ``StoredResult``.

        :param kwargs: accepted for interface compatibility; ignored here.
        """
        self._results[data_version] = StoredResult.new(value=result)

    @override
    def get(self, data_version: str) -> Any | None:
        """Return the value stored for ``data_version``; ``None`` if absent."""
        stored_result = self._results.get(data_version, None)
        if stored_result is None:
            return None

        return stored_result.value

    @override
    def delete(self, data_version: str) -> None:
        """Remove the result stored for ``data_version``.

        :raises KeyError: if ``data_version`` is not stored.
        """
        del self._results[data_version]

    @override
    def delete_all(self) -> None:
        """Remove every stored result."""
        self._results.clear()

    def delete_expired(self) -> None:
        """Remove every stored result whose ``expired`` attribute is truthy."""
        # Collect the keys first, then delete: a dictionary can't be modified
        # while it is being iterated over.
        to_delete = [
            data_version
            for data_version, stored_result in self._results.items()
            if stored_result.expired
        ]

        for data_version in to_delete:
            self.delete(data_version)

    def persist_to(self, result_store: ResultStore | None = None) -> None:
        """Persist in-memory results using another ``ResultStore`` implementation.

        :param result_store: ResultStore implementation to use for persistence.
            If None, a FileResultStore is created with the default path "./.hamilton_cache".
        """
        if result_store is None:
            result_store = FileResultStore(path="./.hamilton_cache")

        for data_version, stored_result in self._results.items():
            result_store.set(data_version, stored_result.value)

    @classmethod
    def load_from(
        cls,
        result_store: ResultStore,
        metadata_store: MetadataStore | None = None,
        data_versions: Sequence[str] | None = None,
    ) -> "InMemoryResultStore":
        """Load in-memory results from another ResultStore instance.

        Since result stores do not store an index of their keys, you must provide a
        ``MetadataStore`` instance or a list of ``data_version`` for which results
        should be loaded in memory. When both are given, the union is loaded.

        A ``data_version`` that ``result_store`` does not hold is skipped, so the
        returned store never reports a result it could not load.

        :param result_store: ``ResultStore`` instance to load results from.
        :param metadata_store: ``MetadataStore`` instance from which all ``data_version``
            are retrieved.
        :param data_versions: explicit ``data_version`` values to load.
        :return: InMemoryResultStore copy of the ``result_store``.
        :raises ValueError: if neither ``metadata_store`` nor ``data_versions`` is given.

        .. code-block:: python

            from hamilton import driver
            from hamilton.caching.stores.file import FileResultStore
            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
            from hamilton.caching.stores.memory import InMemoryResultStore
            import my_dataflow

            sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
            file_result_store = FileResultStore(path="./.hamilton_cache")
            in_memory_result_store = InMemoryResultStore.load_from(
                file_result_store,
                metadata_store=sqlite_metadata_store,
            )

            # create the Driver with the in-memory result store
            dr = (
                driver.Builder()
                .with_modules(my_dataflow)
                .with_cache(result_store=in_memory_result_store)
                .build()
            )
        """
        if metadata_store is None and data_versions is None:
            raise ValueError(
                "A `metadata_store` or `data_versions` must be provided to load results."
            )

        # ``cls()`` rather than the hard-coded class name so that subclasses
        # get an instance of their own type.
        in_memory_result_store = cls()

        # A set de-duplicates versions that appear in several runs or in both inputs.
        data_versions_to_retrieve: set[str] = set()
        if data_versions is not None:
            data_versions_to_retrieve.update(data_versions)

        if metadata_store is not None:
            for run_id in metadata_store.get_run_ids():
                for node_metadata in metadata_store.get_run(run_id):
                    data_versions_to_retrieve.add(node_metadata["data_version"])

        for data_version in data_versions_to_retrieve:
            # ``get`` cannot distinguish "stored value is None" from "not found",
            # so ask the source store explicitly and skip anything it lacks.
            if not result_store.exists(data_version):
                continue

            result = result_store.get(data_version)
            in_memory_result_store.set(data_version, result)

        return in_memory_result_store