# Source code for hamilton.caching.stores.sqlite

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import pathlib
import sqlite3
import threading

from hamilton.caching.cache_key import decode_key
from hamilton.caching.stores.base import MetadataStore


class SQLiteMetadataStore(MetadataStore):
    """Metadata store that persists cache metadata in a SQLite database file.

    The database file is ``<path>/metadata_store.db`` and holds three tables:
    ``run_ids``, ``history`` and ``cache_metadata``. Each thread lazily opens
    its own connection, which is kept in a ``threading.local`` and reused.
    """

    def __init__(
        self,
        path: str,
        connection_kwargs: dict | None = None,
    ) -> None:
        """Create the storage directory and make sure the tables exist.

        :param path: Directory that contains the database file. It is created,
            including missing parents, when it does not exist.
        :param connection_kwargs: Extra keyword arguments forwarded to
            ``sqlite3.connect()``. Must not contain ``check_same_thread``,
            which is always passed explicitly by this class.
        """
        self._directory = pathlib.Path(path).resolve()
        self._directory.mkdir(parents=True, exist_ok=True)
        self._path = self._directory.joinpath("metadata_store").with_suffix(".db")
        self.connection_kwargs: dict = connection_kwargs or {}

        self._thread_local = threading.local()

        # creating tables at `__init__` prevents other methods from encountering
        # `sqlite3.OperationalError` because tables are missing.
        self._create_tables_if_not_exists()

    def __getstate__(self) -> dict:
        """Serialized `__init__` arguments required to initialize the
        MetadataStore in a new thread or process.

        :return: Mapping with the keys ``path`` and ``connection_kwargs``.
        """
        # NOTE kwarg `path` is the directory, which is not equivalent to
        # `self._path` (the database file inside that directory).
        return {
            "path": self._directory,
            "connection_kwargs": self.connection_kwargs,
        }

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        if not hasattr(self._thread_local, "connection"):
            self._thread_local.connection = sqlite3.connect(
                str(self._path), check_same_thread=False, **self.connection_kwargs
            )
        return self._thread_local.connection

    def _close_connection(self) -> None:
        """Close and forget the calling thread's connection, if there is one."""
        # `__del__` can run on an instance whose `__init__` failed before
        # `_thread_local` was assigned; look everything up defensively so that
        # finalization never raises `AttributeError`.
        thread_local = getattr(self, "_thread_local", None)
        connection = getattr(thread_local, "connection", None)
        if connection is not None:
            connection.close()
            del thread_local.connection

    @property
    def connection(self) -> sqlite3.Connection:
        """Connection to the SQLite database."""
        return self._get_connection()

    def __del__(self):
        """Close the SQLite connection when the object is deleted"""
        self._close_connection()

    def _create_tables_if_not_exists(self) -> None:
        """Create the tables necessary for the cache:

        run_ids: queue of run_ids, ordered by start time.
        history: queue of executed node; allows to query "latest" execution of a node
        cache_metadata: information to determine if a node needs to be computed or not

        In the table ``cache_metadata``, the ``cache_key`` is unique whereas
        ``history`` allows duplicate.
        """
        # NOTE(review): `history` and `cache_metadata` declare foreign keys on
        # each other, and `history.cache_key` is not unique. SQLite only
        # enforces foreign keys after `PRAGMA foreign_keys = ON`, which this
        # class never issues, so the clauses look purely declarative here --
        # confirm before enabling enforcement. The schema is left unchanged to
        # stay identical to databases that already exist.
        cur = self.connection.cursor()

        cur.execute(
            """\
            CREATE TABLE IF NOT EXISTS run_ids (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                run_id TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        cur.execute(
            """\
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                cache_key TEXT,
                run_id TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (cache_key) REFERENCES cache_metadata(cache_key)
            );
            """
        )
        cur.execute(
            """\
            CREATE TABLE IF NOT EXISTS cache_metadata (
                cache_key TEXT PRIMARY KEY,
                data_version TEXT NOT NULL,
                node_name TEXT NOT NULL,
                code_version TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (cache_key) REFERENCES history(cache_key)
            );
            """
        )
        self.connection.commit()

    def initialize(self, run_id) -> None:
        """Call initialize when starting a run. This will create database tables
        if necessary.

        :param run_id: Identifier of the run that is starting; appended to the
            ``run_ids`` table.
        """
        # `delete_all()` drops the tables; recreating them here (a no-op when
        # they already exist) keeps the store usable for the next run.
        self._create_tables_if_not_exists()

        cur = self.connection.cursor()
        cur.execute("INSERT INTO run_ids (run_id) VALUES (?)", (run_id,))
        self.connection.commit()

    def __len__(self) -> int:
        """Number of entries in cache_metadata"""
        cur = self.connection.cursor()
        cur.execute("SELECT COUNT(*) FROM cache_metadata")
        return cur.fetchone()[0]

    def set(
        self,
        *,
        cache_key: str,
        data_version: str,
        run_id: str,
        node_name: str | None = None,
        code_version: str | None = None,
        **kwargs,
    ) -> None:
        """Record one node execution and, if new, its cache metadata.

        A row is always appended to ``history``. The ``cache_metadata`` row is
        written with ``INSERT OR IGNORE``, so an existing row for ``cache_key``
        is left untouched.

        :param cache_key: Key identifying the cached result.
        :param data_version: Version of the data produced for ``cache_key``.
        :param run_id: Run during which the node was executed.
        :param node_name: Name of the node. Decoded from ``cache_key`` when
            either ``node_name`` or ``code_version`` is missing.
        :param code_version: Version of the node's code. Decoded from
            ``cache_key`` when either ``node_name`` or ``code_version`` is
            missing.
        :param kwargs: Accepted and ignored.
        :raises ValueError: if ``cache_key`` has to be decoded and decoding fails.
        """
        # if the caller of ``.set()`` directly provides the ``node_name`` and
        # ``code_version``, we can skip the decoding step.
        if node_name is None or code_version is None:
            try:
                decoded_key = decode_key(cache_key)
            except Exception as e:
                # `Exception` rather than `BaseException`, so that
                # KeyboardInterrupt / SystemExit are not turned into ValueError.
                # The message is one string: comma-separated parts would become
                # separate exception args and render as a tuple.
                raise ValueError(
                    f"Failed decoding the cache_key: {cache_key}.\n"
                    "The `cache_key` must be created by "
                    "`hamilton.caching.cache_key.create_cache_key()` "
                    "if `node_name` and `code_version` are not provided."
                ) from e

            node_name = decoded_key["node_name"]
            code_version = decoded_key["code_version"]

        cur = self.connection.cursor()
        cur.execute("INSERT INTO history (cache_key, run_id) VALUES (?, ?)", (cache_key, run_id))
        cur.execute(
            """\
            INSERT OR IGNORE INTO cache_metadata (
                cache_key, node_name, code_version, data_version
            ) VALUES (?, ?, ?, ?)
            """,
            (cache_key, node_name, code_version, data_version),
        )
        self.connection.commit()

    def get(self, cache_key: str) -> str | None:
        """Return the ``data_version`` stored for ``cache_key``.

        :param cache_key: Key to look up in ``cache_metadata``.
        :return: The stored ``data_version``, or ``None`` when there is no row
            for ``cache_key``.
        """
        cur = self.connection.cursor()
        cur.execute(
            """\
            SELECT data_version
            FROM cache_metadata
            WHERE cache_key = ?
            """,
            (cache_key,),
        )
        row = cur.fetchone()
        return None if row is None else row[0]

    def delete(self, cache_key: str) -> None:
        """Delete metadata associated with ``cache_key``.

        Only the ``cache_metadata`` row is removed; ``history`` rows are kept.
        """
        cur = self.connection.cursor()
        cur.execute("DELETE FROM cache_metadata WHERE cache_key = ?", (cache_key,))
        self.connection.commit()

    def delete_all(self) -> None:
        """Delete all existing tables from the database"""
        cur = self.connection.cursor()

        # Table names are fixed literals, so formatting them into the statement
        # is safe (identifiers cannot be bound as SQL parameters).
        for table_name in ["run_ids", "history", "cache_metadata"]:
            cur.execute(f"DROP TABLE IF EXISTS {table_name};")

        self.connection.commit()

    def exists(self, cache_key: str) -> bool:
        """boolean check if a ``data_version`` is found for ``cache_key``
        If True, ``.get()`` should successfully retrieve the ``data_version``.
        """
        cur = self.connection.cursor()
        cur.execute("SELECT cache_key FROM cache_metadata WHERE cache_key = ?", (cache_key,))
        return cur.fetchone() is not None

    def get_run_ids(self) -> list[str]:
        """Return a list of run ids, sorted from oldest to newest start time."""
        cur = self.connection.cursor()
        # `id` is the autoincrement key, so ordering by it is insertion order.
        cur.execute("SELECT run_id FROM run_ids ORDER BY id")
        return [row[0] for row in cur.fetchall()]

    def _run_exists(self, run_id: str) -> bool:
        """Returns True if a run was initialized with ``run_id``, even
        if the run recorded no node executions.
        """
        cur = self.connection.cursor()
        cur.execute(
            """\
            SELECT EXISTS(
                SELECT 1
                FROM run_ids
                WHERE run_id = ?
            )
            """,
            (run_id,),
        )
        result = cur.fetchone()
        # SELECT EXISTS returns 1 for True, i.e., `run_id` is found
        return result[0] == 1

    def get_run(self, run_id: str) -> list[dict]:
        """Return a list of node metadata associated with a run.

        :param run_id: ID of the run to retrieve
        :return: List of node metadata which includes ``cache_key``, ``data_version``,
            ``node_name``, and ``code_version``. The list can be empty if a run was initialized
            but no nodes were executed.
        :raises IndexError: if the ``run_id`` is not found in metadata store.
        """
        if not self._run_exists(run_id):
            raise IndexError(f"`run_id` not found in table `run_ids`: {run_id}")

        cur = self.connection.cursor()
        cur.execute(
            """\
            SELECT
                cache_metadata.cache_key,
                cache_metadata.data_version,
                cache_metadata.node_name,
                cache_metadata.code_version
            FROM history
            JOIN cache_metadata ON history.cache_key = cache_metadata.cache_key
            WHERE history.run_id = ?
            """,
            (run_id,),
        )

        return [
            dict(
                cache_key=cache_key,
                data_version=data_version,
                node_name=node_name,
                code_version=code_version,
            )
            for cache_key, data_version, node_name, code_version in cur.fetchall()
        ]