# Source code for hamilton.plugins.h_tqdm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Collection
from typing import Any
import tqdm
from hamilton import graph_types
from hamilton.lifecycle import GraphExecutionHook, NodeExecutionHook
[docs]
class ProgressBar(
    GraphExecutionHook,
    NodeExecutionHook,
):
    """An adapter that uses tqdm to show progress bars for the graph execution.

    Note: you need to have tqdm installed for this to work.
    If you don't have it installed, you can install it with `pip install tqdm`
    (or `pip install apache-hamilton[tqdm]` -- use quotes if you're using zsh).

    .. code-block:: python

        from hamilton.plugins import h_tqdm

        dr = (
            driver.Builder()
            .with_config({})
            .with_modules(some_modules)
            .with_adapters(h_tqdm.ProgressBar(desc="DAG-NAME"))
            .build()
        )
        # and then when you call .execute() or .materialize() you'll get a progress bar!
    """

    def __init__(
        self, desc: str = "Graph execution", max_node_name_width: int = 50, **kwargs: Any
    ) -> None:
        """Create a new Progress Bar adapter.

        :param desc: The description to show in the progress bar. E.g. DAG Name is a good choice.
        :param max_node_name_width: The cap on the width that node names are padded/truncated
            to, so that the progress bar stays a consistent width. The actual target width is
            the longest node name in the first executed graph, limited by this value.
        :param kwargs: Additional kwargs to pass to TQDM. See TQDM docs for more info.
        """
        self.desc = desc
        self.kwargs = kwargs
        # What we target padding for -- starts at None and is fixed the first time
        # run_before_graph_execution sees an execution path.
        self.node_name_target_width: int | None = None
        self.max_node_name_width = max_node_name_width  # what we cap the padding at.
        # Created in run_before_graph_execution; None until then.
        self.progress_bar: tqdm.tqdm | None = None

    def _get_node_name_display(self, node_name: str) -> str:
        """Pad or truncate ``node_name`` to ``self.node_name_target_width`` characters.

        Names shorter than the target are right-padded with spaces; longer names are cut
        and end in ``"..."``.

        :param node_name: The node name to render.
        :return: The display string. The name is returned unchanged when no target width
            has been set yet.
        """
        width = self.node_name_target_width
        if width is None:
            # No graph-level width has been computed yet, so there is nothing to align to.
            return node_name
        if len(node_name) <= width:
            return node_name.ljust(width)
        if width < 3:
            # Not enough room for the "..." marker: `width - 3` would be a negative slice
            # bound and yield text longer than the target. Hard-truncate instead.
            return node_name[: max(width, 0)]
        return node_name[: width - 3] + "..."

    def run_before_graph_execution(
        self,
        *,
        graph: graph_types.HamiltonGraph,
        final_vars: list[str],
        inputs: dict[str, Any],
        overrides: dict[str, Any],
        execution_path: Collection[str],
        **future_kwargs: Any,
    ) -> None:
        """Create the tqdm bar, sized to the number of nodes in ``execution_path``.

        Also fixes the node-name display width on the first call: the longest name in
        ``execution_path``, capped at ``max_node_name_width``.

        :param graph: Unused here.
        :param final_vars: Unused here.
        :param inputs: Unused here.
        :param overrides: Unused here.
        :param execution_path: Names of the nodes that will be executed.
        :param future_kwargs: Ignored; accepted for forward compatibility.
        """
        if self.node_name_target_width is None:
            # `default=0` keeps an empty execution path from raising ValueError in max().
            longest_name = max((len(node) for node in execution_path), default=0)
            self.node_name_target_width = min(longest_name, self.max_node_name_width)
        self.progress_bar = tqdm.tqdm(
            desc=self.desc, unit="funcs", total=len(execution_path), **self.kwargs
        )

    def run_before_node_execution(
        self,
        *,
        node_name: str,
        node_tags: dict[str, Any],
        node_kwargs: dict[str, Any],
        node_return_type: type,
        task_id: str | None,
        **future_kwargs: Any,
    ) -> None:
        """Show the name of the node that is about to run in the bar's description.

        :param node_name: Name of the node about to execute.
        :param node_tags: Unused here.
        :param node_kwargs: Unused here.
        :param node_return_type: Unused here.
        :param task_id: Unused here.
        :param future_kwargs: Ignored; accepted for forward compatibility.
        """
        if self.progress_bar is None:
            # No bar was created (graph-level hook did not run); nothing to update.
            return
        name_display = self._get_node_name_display(node_name)
        self.progress_bar.set_description_str(f"{self.desc} -> {name_display}")

    def run_after_node_execution(self, **future_kwargs: Any) -> None:
        """Advance the bar by one completed node.

        :param future_kwargs: Ignored; accepted for forward compatibility.
        """
        if self.progress_bar is None:
            return
        self.progress_bar.update(1)

    def run_after_graph_execution(self, *, success: bool = True, **future_kwargs: Any) -> None:
        """Finalize the bar: top it up on success, show a completion message, and close it.

        :param success: Whether the graph run succeeded. Only a successful run is topped
            up to 100%.
        :param future_kwargs: Ignored; accepted for forward compatibility.
        """
        bar = self.progress_bar
        if bar is None:
            # The bar was never created, so there is nothing to finalize or close.
            return
        # Pad the final message to the same width used for node names so the bar does
        # not jump; a message longer than the target width is shown unpadded.
        name_part = "Execution Complete!".ljust(self.node_name_target_width or 0)
        # Overrides are counted in `total` but never trigger run_after_node_execution,
        # so on a successful run we top the bar up to 100% rather than leaving a gap.
        if success and bar.total is not None:
            remaining = bar.total - bar.n
            if remaining > 0:
                bar.update(remaining)
        bar.set_description_str(f"{self.desc} -> {name_part}")
        bar.set_postfix({})
        bar.close()