# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import pickle
import warnings
from typing import Any
import mlflow
import mlflow.data
from hamilton import graph_types
from hamilton.lifecycle import GraphConstructionHook, GraphExecutionHook, NodeExecutionHook
# silence odd ongoing MLFlow issue that spams warnings
# GitHub Issue https://github.com/mlflow/mlflow/issues/8605
warnings.filterwarnings("ignore", category=UserWarning)
FIGURE_TYPES = []
try:
import matplotlib.figure
FIGURE_TYPES.append(matplotlib.figure.Figure)
except ImportError:
pass
try:
import plotly.graph_objects
FIGURE_TYPES.append(plotly.graph_objects.Figure)
except ImportError:
pass
logger = logging.getLogger(__name__)
def get_path_from_metadata(metadata: dict) -> str | None:
"""Retrieve the `path` attribute from DataSaver output metadata"""
path = None
if "path" in metadata:
path = metadata["path"]
elif "file_metadata" in metadata:
path = metadata["file_metadata"]["path"]
return path
# NOTE `mlflow.client.MLFlowClient` is preferred to top-level `mlflow.` methods in MLFlowTracker
# because the latter relies on hard-to-debug global variables. Yet, we set an `active_run` by using
# `mlflow.start_run()` in pre_graph_execution to ensure the user-specified MLFlow code
# and MLFlow materializers log metrics and models to the same run as the MLFlowTracker
[docs]
class MLFlowTracker(
NodeExecutionHook,
GraphExecutionHook,
GraphConstructionHook,
):
"""Driver adapter logging Hamilton execution results to an MLFlow server."""
[docs]
def __init__(
self,
tracking_uri: str | None = None,
registry_uri: str | None = None,
artifact_location: str | None = None,
experiment_name: str = "Hamilton",
experiment_tags: dict | None = None,
experiment_description: str | None = None,
run_id: str | None = None,
run_name: str | None = None,
run_tags: dict | None = None,
run_description: str | None = None,
log_system_metrics: bool = False,
):
"""Configure the MLFlow client and experiment for the lifetime of the tracker
:param tracking_uri: Destination of the logged artifacts and metadata. It can be a filesystem, database, or server. [reference](https://mlflow.org/docs/latest/getting-started/tracking-server-overview/index.html)
:param registry_uri: Destination of the registered models. By default it's the same as the tracking destination, but they can be different. [reference](https://mlflow.org/docs/latest/getting-started/registering-first-model/index.html)
:param artifact_location: Root path on tracking server where experiment is stored
:param experiment_name: MLFlow experiment name used to group runs.
:param experiment_tags: Tags to query experiments programmatically (not displayed).
:param experiment_description: Description of the experiment displayed
:param run_id: Run id to log to an existing run (every execution logs to the same run)
:param run_name: Run name displayed and used to query runs. You can have multiple runs with the same name but different run ids.
:param run_tags: Tags to query runs and appears as columns in the UI for filtering and grouping. It automatically includes serializable inputs and Driver config.
:param run_description: Description of the run displayed
:param log_system_metrics: Log system metrics to display (requires additonal dependencies)
"""
self.client = mlflow.client.MlflowClient(tracking_uri, registry_uri)
# experiment setup
experiment_tags = experiment_tags or {}
if experiment_description:
# mlflow.note.content is the description field
experiment_tags["mlflow.note.content"] = experiment_description
# TODO link HamiltonTracker project and MLFlowTracker experiment
experiment = self.client.get_experiment_by_name(experiment_name)
if experiment:
experiment_id = experiment.experiment_id
# update tags and description of an existing experiment
if experiment_tags:
for k, v in experiment_tags.items():
self.client.set_experiment_tag(experiment_id, key=k, value=v)
# create an experiment
else:
experiment_id = self.client.create_experiment(
name=experiment_name,
artifact_location=artifact_location,
tags=experiment_tags,
)
self.experiment_id = experiment_id
# run setup
# TODO link HamiltonTracker and MLFlowTracker run ids
self.mlflow_run_id = run_id
self.run_name = run_name
self.run_tags = run_tags or {}
if run_description:
# mlflow.note.content is the description field
self.run_tags["mlflow.note.content"] = run_description
self.log_system_metrics = log_system_metrics
[docs]
def run_after_graph_construction(self, *, config: dict[str, Any], **kwargs):
"""Store the Driver config before creating the graph"""
self.config = config
[docs]
def run_before_graph_execution(
self,
*,
run_id: str,
final_vars: list[str],
inputs: dict[str, Any],
graph: graph_types.HamiltonGraph,
**kwargs,
):
"""Create and start MLFlow run. Log graph version, run_id, inputs, overrides"""
# add Hamilton metadata to run tags
run_tags = self.run_tags
run_tags["hamilton_run_id"] = run_id # the Hamilton run_id
run_tags["code_version"] = graph.version
# create Hamilton run
self.run = self.client.create_run(
experiment_id=self.experiment_id,
tags=run_tags,
run_name=self.run_name,
)
self.run_id = self.run.info.run_id
# start run to set `active_run` and allow user-defined callbacks and materializers
# to log to the same run as the HamiltonTracker
mlflow.start_run(
run_id=self.run_id,
experiment_id=self.experiment_id,
tags=run_tags,
log_system_metrics=self.log_system_metrics,
)
# log config to artifacts
self.client.log_dict(self.run_id, self.config, "config.json")
# log HamiltonGraph to reproduce the run
self.graph = graph
graph_as_json = {n.name: n.as_dict() for n in graph.nodes}
self.client.log_dict(self.run_id, graph_as_json, "hamilton_graph.json")
# log config and inputs as `param` which creates columns in the UI to filter runs
# `log_param()` accepts `value: Any` and will stringify complex objects
for value_sets in [self.config, inputs]:
if value_sets is None:
continue
for node_name, value in value_sets.items():
self.client.log_param(self.run_id, key=node_name, value=value)
self.final_vars = final_vars
# TODO log DataLoaders as MLFlow datasets
[docs]
def run_after_node_execution(
self,
*,
node_name: str,
node_return_type: type,
node_tags: dict,
node_kwargs: dict,
result: Any,
**kwargs,
):
"""Log materializers and final vars as artifacts"""
# log DataSavers as artifacts
# TODO refactor if/else as `handle_materializers()`
if node_tags.get("hamilton.data_saver") is True:
# don't log mlflow materializers as artifact since they already create models
# instead, use the Materializer metadata to add metadata to registered models
if node_tags["hamilton.data_saver.sink"] == "mlflow":
# skip if not registered model
if "registered_model" not in result.keys():
return
# get the registered model name (param of MLFlowModelSaver)
model_name = result["registered_model"]["name"]
version = result["registered_model"]["version"]
materializer_node = self.graph[node_name]
# get the "materialized node" defining the model
materialized_node = self.graph[materializer_node.required_dependencies.pop()]
# add the materialized node docstring as description
# registered models have multiple versions
self.client.update_registered_model(model_name, materialized_node.documentation)
self.client.update_model_version(
model_name, version, materialized_node.documentation
)
# add node name as tag
self.client.set_registered_model_tag(
model_name, key="node_name", value=materialized_node.name
)
self.client.set_model_version_tag(
model_name, version, key="node_name", value=materialized_node.name
)
# add origin function name as tag
self.client.set_registered_model_tag(
model_name,
key="function_name",
value=materialized_node.originating_functions[0].__name__,
)
self.client.set_model_version_tag(
model_name,
version,
key="function_name",
value=materialized_node.originating_functions[0].__name__,
)
# add the materialized node @tag values as tags
for k, v in materialized_node.tags.items():
# skip internal Hamilton tags
if "hamilton." in k:
continue
self.client.set_registered_model_tag(model_name, key=k, value=v)
self.client.set_model_version_tag(model_name, version, key=k, value=v)
# TODO automatically collect model input signature; maybe simpler from user code
# special case for matplotlib and plotly
# log materialized figure. Allows great degree of control over rendering format
# and also save interactive plotly visualization as HTML
elif node_tags["hamilton.data_saver.sink"] in ["plt", "plotly"]:
materializer_node = self.graph[node_name]
materialized_node = self.graph[materializer_node.required_dependencies.pop()]
figure = node_kwargs[materialized_node.name]
path = get_path_from_metadata(result)
if path:
self.client.log_figure(self.run_id, figure, path)
else:
logger.warning(
f"Materialization result from node={node_name} has no recordable path: {result}. Materializer must have either "
f"'path' or 'file_metadata' keys."
)
else:
# log the materializer path as an artifact
path = get_path_from_metadata(result)
if path:
self.client.log_artifact(self.run_id, path, node_name)
else:
logger.warning(
f"Materialization result from node={node_name} has no recordable path: {result}. Materializer must have either "
f"'path' or 'file_metadata' keys."
)
return
# log final_vars as artifacts
if node_name not in self.final_vars:
return
# TODO refactor if/else as `handle_final_vars()`
# log float and int as metrics
if node_return_type in [float, int]:
self.client.log_metric(self.run_id, key=node_name, value=float(result))
# log str as text in .txt format
elif isinstance(node_return_type, str):
file_path = f"{node_name}.txt"
with open(file_path, "w") as f:
f.write(result)
self.client.log_text(self.run_id, result, file_path)
# log_dict (JSON) dictionary types; pickle if not json-serializable
elif isinstance(node_return_type, dict):
try:
file_path = f"{node_name}.json"
self.client.log_dict(self.run_id, result, file_path)
# not json-serializable
except TypeError:
file_path = f"{node_name}.pickle"
with open(file_path, "wb") as f:
pickle.dump(result, file=f)
self.client.log_dict(self.run_id, result, file_path)
# this puts less burden on users by not having to define materializers
# for viz, but less control over rendering format
elif node_return_type in FIGURE_TYPES:
file_path = f"{node_name}.png"
self.client.log_figure(self.run_id, result, file_path)
# default to log_artifact in .pickle format
else:
file_path = f"{node_name}.pickle"
with open(file_path, "wb") as f:
pickle.dump(result, f)
self.client.log_dict(self.run_id, result, file_path)
[docs]
def run_after_graph_execution(self, success: bool, *args, **kwargs):
"""End the MLFlow run"""
# `status` is an enum value of mlflow.entities.RunStatus
status = "FINISHED" if success else "FAILED"
self.client.set_terminated(self.run_id, status=status)
mlflow.end_run(status=status)
[docs]
def run_before_node_execution(self, *args, **kwargs):
"""Placeholder required to subclass NodeExecutionHook"""