Source code for hamilton.caching.stores.file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import shutil
from pathlib import Path
from typing import Any
try:
    from typing import override
except ImportError:
    # `typing.override` is not importable on this interpreter; fall back to
    # a decorator that hands the decorated object back unchanged.
    def override(x):
        """Return ``x`` unchanged (stand-in for ``typing.override``)."""
        return x
from hamilton.io.data_adapters import DataLoader, DataSaver
from .base import ResultStore, StoredResult
[docs]
class FileResultStore(ResultStore):
    """Result store that keeps every cached result as a file in one directory.

    A result for ``data_version`` is written to ``<path>/<data_version>``.
    When a saver/loader pair is passed to :meth:`set`, both are pointed at
    that same path with its suffix set to ``.<saver_cls.name()>``.
    """

    def __init__(self, path: str, create_dir: bool = True) -> None:
        """Create a store rooted at ``path``.

        :param path: Directory that holds the stored results.
        :param create_dir: If True, create ``path`` (and any missing parents)
            immediately.
        """
        self.path = Path(path)
        self.create_dir = create_dir

        if self.create_dir:
            self.path.mkdir(exist_ok=True, parents=True)

    def __getstate__(self) -> dict:
        """Serialize the `__init__` kwargs to pass in Parallelizable branches
        when using multiprocessing.

        :return: Mapping that can be splatted into ``__init__``.
        """
        return {"path": str(self.path), "create_dir": self.create_dir}

    def __setstate__(self, state: dict) -> None:
        """Rebuild the store from the `__init__` kwargs made by `__getstate__`.

        Without this hook, unpickling copies the state straight into
        ``__dict__`` and leaves ``self.path`` as a ``str``, which breaks every
        method that relies on ``Path`` operations.

        :param state: Mapping with ``path`` and, optionally, ``create_dir``.
        """
        self.path = Path(state["path"])
        # States that only carry `path` fall back to the `__init__` default.
        self.create_dir = state.get("create_dir", True)

        if self.create_dir:
            self.path.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _write_result(file_path: Path, stored_result: StoredResult) -> None:
        """Write the bytes returned by ``stored_result.save()`` to ``file_path``."""
        file_path.write_bytes(stored_result.save())

    @staticmethod
    def _load_result_from_path(path: Path) -> StoredResult | None:
        """Read ``path`` and decode its bytes with ``StoredResult.load``.

        :param path: File to read.
        :return: The decoded result, or None if ``FileNotFoundError`` is raised.
        """
        try:
            data = path.read_bytes()
            # NOTE(review): decoding happens inside the `try`, so a
            # FileNotFoundError raised by `StoredResult.load` is also reported
            # as None. Kept as-is in case callers rely on it -- confirm intent.
            return StoredResult.load(data)
        except FileNotFoundError:
            return None

    @staticmethod
    def _build_adapter(adapter_cls: Any, materialized_path: Path, role: str) -> Any:
        """Instantiate a saver or loader class pointed at ``materialized_path``.

        The absolute path is passed as ``file=`` if the class's ``__init__``
        accepts that name, otherwise as ``path=``.

        :param adapter_cls: Saver or loader class to instantiate.
        :param materialized_path: File the adapter should target.
        :param role: ``"Saver"`` or ``"Loader"``; only used in the error message.
        :return: The new adapter instance.
        :raises ValueError: If ``__init__`` accepts neither ``file`` nor ``path``.
        """
        argspec = inspect.getfullargspec(adapter_cls.__init__)
        # The target is always passed by keyword, so keyword-only parameters
        # are just as usable as positional-or-keyword ones.
        accepted = {*argspec.args, *argspec.kwonlyargs}
        target = str(materialized_path.absolute())

        if "file" in accepted:
            return adapter_cls(file=target)
        if "path" in accepted:
            return adapter_cls(path=target)
        raise ValueError(
            f"{role} [{adapter_cls.name()}] must have either `file` or `path` as an argument."
        )

    def _path_from_data_version(self, data_version: str) -> Path:
        """Return the path of the result file for ``data_version``."""
        return self.path.joinpath(data_version)

    def _materialized_path(self, data_version: str, saver_cls: DataSaver) -> Path:
        """Return the result path with its suffix set to ``.<saver_cls.name()>``."""
        # TODO allow a more flexible mechanism to specify file path extension
        return self._path_from_data_version(data_version).with_suffix(f".{saver_cls.name()}")

    @override
    def exists(self, data_version: str) -> bool:
        """Return True if a result file exists for ``data_version``."""
        result_path = self._path_from_data_version(data_version)
        return result_path.exists()

    @override
    def set(
        self,
        data_version: str,
        result: Any,
        saver_cls: DataSaver | None = None,
        loader_cls: DataLoader | None = None,
    ) -> None:
        """Store ``result`` under ``data_version``.

        :param data_version: Name of the result file inside the store directory.
        :param result: Value handed to ``StoredResult.new``.
        :param saver_cls: Optional saver class; it is instantiated here with
            the materialized path. Must be passed together with ``loader_cls``.
        :param loader_cls: Optional loader class, instantiated with the same
            materialized path as the saver.
        :raises ValueError: If only one of ``saver_cls`` / ``loader_cls`` is
            given, or if either class accepts neither ``file`` nor ``path``.
        """
        # != operator on boolean is XOR
        if (saver_cls is None) != (loader_cls is None):
            raise ValueError(
                "Must pass both `saver` and `loader` or neither. Currently received: "
                f"`saver`: `{saver_cls}`; `loader`: `{loader_cls}`"
            )

        if saver_cls is not None:
            # Saver and loader share one target so the loader reads back what
            # the saver is configured to write.
            materialized_path = self._materialized_path(data_version, saver_cls)
            saver = self._build_adapter(saver_cls, materialized_path, "Saver")
            loader = self._build_adapter(loader_cls, materialized_path, "Loader")
        else:
            saver = None
            loader = None

        # The directory may not exist yet (`create_dir=False`) or may have been
        # removed since construction; `parents=True` matches `__init__`.
        self.path.mkdir(exist_ok=True, parents=True)
        result_path = self._path_from_data_version(data_version)
        stored_result = StoredResult.new(value=result, saver=saver, loader=loader)
        self._write_result(result_path, stored_result)

    @override
    def get(self, data_version: str) -> Any | None:
        """Return the stored value for ``data_version``.

        :return: ``value`` of the loaded result, or None when
            ``_load_result_from_path`` returns None.
        """
        result_path = self._path_from_data_version(data_version)
        stored_result = self._load_result_from_path(result_path)

        if stored_result is None:
            return None

        return stored_result.value

    @override
    def delete(self, data_version: str) -> None:
        """Remove the result file for ``data_version``; a missing file is ignored."""
        result_path = self._path_from_data_version(data_version)
        result_path.unlink(missing_ok=True)

    @override
    def delete_all(self) -> None:
        """Remove the whole store directory and recreate it empty."""
        try:
            shutil.rmtree(self.path)
        except FileNotFoundError:
            # Nothing to clear, e.g. `create_dir=False` and nothing stored yet.
            pass
        self.path.mkdir(exist_ok=True, parents=True)

    def delete_expired(self) -> None:
        """Delete every result file whose loaded result reports ``expired``."""
        if not self.path.is_dir():
            # No directory means there is nothing that could have expired.
            return

        for file_path in self.path.iterdir():
            # Results are always written as plain files; `read_bytes` would
            # raise on a sub-directory.
            if not file_path.is_file():
                continue

            # NOTE(review): files targeted by savers live in this directory too
            # (see `_materialized_path`) and are also passed to
            # `StoredResult.load` here -- confirm it tolerates them.
            stored_result = self._load_result_from_path(file_path)
            if stored_result and stored_result.expired:
                file_path.unlink(missing_ok=True)