# Source code for hamilton.plugins.h_rich
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Collection
from typing import Any
import rich.progress
from hamilton.execution.grouping import NodeGroupPurpose
try:
from typing import override
except ImportError:
override = lambda x: x # noqa E731
import rich
from rich.progress import Progress
from hamilton.lifecycle import (
GraphExecutionHook,
NodeExecutionHook,
TaskExecutionHook,
TaskGroupingHook,
)
# [docs]
class RichProgressBar(TaskExecutionHook, TaskGroupingHook, GraphExecutionHook, NodeExecutionHook):
    """An adapter that uses rich to show simple progress bars for the graph execution.

    Note: you need to have rich installed for this to work. If you don't have it installed, you can
    install it with `pip install rich` (or `pip install apache-hamilton[rich]` -- use quotes if you're
    using zsh).

    .. code-block:: python

        from hamilton import driver
        from hamilton.plugins import h_rich

        dr = (
            driver.Builder()
            .with_config({})
            .with_modules(some_modules)
            .with_adapters(h_rich.RichProgressBar())
            .build()
        )

    and then when you call .execute() or .materialize() you'll get a progress bar!

    Additionally, this progress bar will also work with task-based execution, showing the progress
    of overall execution as well as the tasks within a parallelized group.

    .. code-block:: python

        from hamilton import driver
        from hamilton.execution import executors
        from hamilton.plugins import h_rich

        dr = (
            driver.Builder()
            .with_modules(__main__)
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_adapters(h_rich.RichProgressBar())
            .with_local_executor(executors.SynchronousLocalTaskExecutor())
            .with_remote_executor(executors.SynchronousLocalTaskExecutor())
            .build()
        )

    The same adapter instance can be used for several consecutive runs: bars left over from a
    previous run are removed when a new graph execution starts.
    """

    def __init__(
        self,
        run_desc: str = "",
        collect_desc: str = "",
        columns: list[str | rich.progress.ProgressColumn] | None = None,
        **kwargs,
    ) -> None:
        """Create a new Rich Progress Bar adapter.

        :param run_desc: The description to show for the running phase.
        :param collect_desc: The description to show for the collecting phase (if applicable).
        :param columns: Column configuration for the progress bar. See rich docs for more info.
        :param kwargs: Additional kwargs to pass to rich.progress.Progress. See rich docs for more info.
        """
        # Empty descriptions fall back to the default labels.
        self._group_desc = run_desc or "Running:"
        self._expand_desc = collect_desc or "Collecting:"
        columns = columns or []
        self._progress = Progress(*columns, **kwargs)
        # Set to True once task grouping is reported, i.e. task-based execution is in use.
        self._task_based = False

    @override
    def run_before_graph_execution(self, *, execution_path: Collection[str], **kwargs: Any):
        """Create the overall progress bar and start rendering.

        :param execution_path: Names of the nodes that will be executed; its length is the
            initial total of the overall bar.
        :param kwargs: Remaining hook arguments (unused).
        """
        # Every hook addresses the overall bar as task_ids[0]. Drop bars left over from a
        # previous run so that index refers to the bar created below, not to a finished one.
        for stale_task_id in self._progress.task_ids:
            self._progress.remove_task(stale_task_id)
        self._progress.add_task(self._group_desc, total=len(execution_path))
        self._progress.start()

    @override
    def run_after_graph_execution(self, **kwargs: Any):
        """Stop rendering once the graph has finished executing.

        :param kwargs: Hook arguments (unused).
        """
        self._progress.stop()  # in case progress thread is lagging

    @override
    def run_after_task_grouping(self, *, task_ids: list[str], **kwargs):
        """Switch the overall bar from counting nodes to counting tasks.

        :param task_ids: IDs of the tasks the graph was grouped into.
        :param kwargs: Remaining hook arguments (unused).
        """
        # Change the total of the task group to the number of tasks in the group
        self._progress.update(self._progress.task_ids[0], total=len(task_ids))
        # From now on progress is driven by the task hooks, not the node hooks.
        self._task_based = True

    @override
    def run_after_task_expansion(self, *, parameters: dict[str, Any], **kwargs):
        """Add a bar for an expanded (parallelized) group of tasks.

        :param parameters: Parameters the task was expanded over; one unit of progress each.
        :param kwargs: Remaining hook arguments (unused).
        """
        self._progress.add_task(self._expand_desc, total=len(parameters))

    @override
    def run_before_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
        """Advance the overall bar and stop the newest bar when a gather task is about to run.

        :param purpose: Purpose of the task that is about to execute.
        :param kwargs: Remaining hook arguments (unused).
        """
        if purpose == NodeGroupPurpose.GATHER:
            self._progress.advance(self._progress.task_ids[0])
            # The most recently added bar is the collecting bar once an expansion has happened.
            self._progress.stop_task(self._progress.task_ids[-1])

    @override
    def run_after_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
        """Advance the bar that the finished task contributes to.

        :param purpose: Purpose of the task that just finished.
        :param kwargs: Remaining hook arguments (unused).
        """
        if purpose == NodeGroupPurpose.EXECUTE_BLOCK:
            # Block tasks count towards the most recently added bar.
            self._progress.advance(self._progress.task_ids[-1])
        else:
            # Every other task counts towards the overall bar.
            self._progress.advance(self._progress.task_ids[0])

    @override
    def run_before_node_execution(self, **kwargs):
        """Do nothing; progress is only recorded after a node finishes.

        :param kwargs: Hook arguments (unused).
        """
        pass

    @override
    def run_after_node_execution(self, **kwargs):
        """Advance the overall bar by one node, unless task-based execution is in use.

        :param kwargs: Hook arguments (unused).
        """
        # In task-based execution the task hooks advance the bars instead.
        if not self._task_based:
            self._progress.advance(self._progress.task_ids[0])