# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import abc
import pickle
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import Any
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
[docs]
class ResultRetrievalError(Exception):
    """Signals that ``ResultStore.get()`` failed; raised by the SmartCacheAdapter."""
# TODO: Currently, this check (that `name` maps to a DataSaver/DataLoader pair supporting
# the given type) is only done when data needs to be saved.
# Ideally, it would be done earlier in the caching lifecycle.
[docs]
def search_data_adapter_registry(
    name: str, type_: type
) -> tuple[type[DataSaver], type[DataLoader]]:
    """Look up the DataSaver/DataLoader classes registered under ``name`` that support ``type_``.

    For each registry, the first registered class is selected for which
    ``custom_subclass_check(type_, applicable_type)`` is truthy for at least one
    of the class's ``applicable_types()``.

    :param name: key expected to be present in both ``SAVER_REGISTRY`` and ``LOADER_REGISTRY``.
    :param type_: type that both the saver and the loader must support.
    :return: the ``(saver_cls, loader_cls)`` pair.
    :raises KeyError: if ``name`` is missing from either registry, or if no registered
        saver (checked first) or loader supports ``type_``.
    """
    registered_for_both = name in SAVER_REGISTRY and name in LOADER_REGISTRY
    if not registered_for_both:
        raise KeyError(
            f"{name} isn't associated to both a DataLoader and a DataSaver. "
            "Default saver/loader pairs include `json`, `file`, `pickle`, `parquet`, `csv`, "
            "`feather`, `orc`, `excel`. More pairs may be available through plugins."
        )

    def _supports(adapter_cls) -> bool:
        # True when any type declared by the adapter passes the subclass check for `type_`.
        return any(
            custom_subclass_check(type_, candidate)
            for candidate in adapter_cls.applicable_types()
        )

    # `next` raises StopIteration when no registered class matches; it is translated
    # into a KeyError while keeping the StopIteration as the explicit cause.
    try:
        saver_cls = next(filter(_supports, SAVER_REGISTRY[name]))
    except StopIteration as no_match:
        raise KeyError(f"{name} doesn't have any DataSaver supporting type {type_}") from no_match

    try:
        loader_cls = next(filter(_supports, LOADER_REGISTRY[name]))
    except StopIteration as no_match:
        raise KeyError(f"{name} doesn't have any DataLoader supporting type {type_}") from no_match

    return saver_cls, loader_cls
[docs]
class ResultStore(abc.ABC):
[docs]
@abc.abstractmethod
def set(self, data_version: str, result: Any, **kwargs) -> None:
"""Store ``result`` keyed by ``data_version``."""
[docs]
@abc.abstractmethod
def get(self, data_version: str, **kwargs) -> Any | None:
"""Try to retrieve ``result`` keyed by ``data_version``.
If retrieval misses, return ``None``.
"""
[docs]
@abc.abstractmethod
def delete(self, data_version: str) -> None:
"""Delete ``result`` keyed by ``data_version``."""
[docs]
@abc.abstractmethod
def delete_all(self) -> None:
"""Delete all stored results."""
[docs]
@abc.abstractmethod
def exists(self, data_version: str) -> bool:
"""boolean check if a ``result`` is found for ``data_version``
If True, ``.get()`` should successfully retrieve the ``result``.
"""
# TODO refactor the association between StoredResult, MetadataStore, and ResultStore
# to load data using the `DataLoader` class and kwargs instead of pickling the instantiated
# DataLoader object. This would be safer across Hamilton versions.
class StoredResult:
def __init__(
self,
value: Any,
expires_at=None,
saver=None,
loader=None,
):
self.value = value
self.expires_at = expires_at
self.saver = saver
self.loader = loader
@classmethod
def new(
cls,
value: Any,
expires_in: timedelta | None = None,
saver: DataSaver | None = None,
loader: DataLoader | None = None,
) -> "StoredResult":
if expires_in is not None and not isinstance(expires_in, timedelta):
expires_in = timedelta(seconds=expires_in)
# != operator on boolean is XOR
if bool(saver is not None) != bool(loader is not None):
raise ValueError(
"Must pass both `saver` and `loader` or neither. Currently received: "
f"`saver`: `{saver}`; `loader`: `{loader}`"
)
return cls(
value=value,
expires_at=(datetime.now(tz=timezone.utc) + expires_in) if expires_in else None,
saver=saver,
loader=loader,
)
@property
def expired(self) -> bool:
return self.expires_at is not None and datetime.now(tz=timezone.utc) >= self.expires_at
@property
def expires_in(self) -> int:
if self.expires_at:
return int(self.expires_at.timestamp() - datetime.now(tz=timezone.utc).timestamp())
return -1
def save(self) -> bytes:
"""Receives pickleable data or DataLoader to use to load the real data"""
if self.saver is not None:
self.saver.save_data(data=self.value)
to_pickle = self.loader
else:
to_pickle = self.value
return pickle.dumps(to_pickle)
@classmethod
def load(cls, raw: bytes) -> "StoredResult":
"""Reads the raw bytes from disk and sets `StoredResult.data`"""
loaded = pickle.loads(raw)
if isinstance(loaded, DataLoader):
loader = loaded
result, metadata = loader.load_data(None)
else:
loader = None
result = loaded
return StoredResult.new(value=result)