• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Pigweed Console progress bar functions."""
15from pw_console.progress_bar.progress_bar_state import TASKS_CONTEXTVAR
16from pw_console.progress_bar.progress_bar_task_counter import (
17    ProgressBarTaskCounter,
18)
19
20__all__ = [
21    'start_progress',
22    'update_progress',
23]
24
25
def start_progress(task_name: str, total: int, hide_eta=False):
    """Register a new progress task and create its counter.

    The progress state is read from ``TASKS_CONTEXTVAR`` and its progress bar
    implementation is started via ``startup_progress_bar_impl()``. A counter
    is then obtained from ``progress_state.instance`` over ``range(total)``
    labelled ``task_name``. It is wrapped in a ``ProgressBarTaskCounter`` and
    stored in ``progress_state.tasks`` under ``task_name``. Any existing entry
    with the same name is replaced (assuming ``tasks`` is a plain dict).

    Args:
        task_name: Label shown for the counter and key used in
            ``progress_state.tasks``.
        total: Number of items. It is passed both as ``range(total)`` to the
            instance and as ``total`` to ``ProgressBarTaskCounter``.
        hide_eta: Value assigned to the counter's ``hide_eta`` attribute.
            Presumably this hides the ETA column; confirm against the
            formatter.
    """
    state = TASKS_CONTEXTVAR.get()

    # Make sure a progress bar instance exists before creating a counter.
    state.startup_progress_bar_impl()
    assert state.instance is not None

    pt_counter = state.instance(range(total), label=task_name)
    task_counter = ProgressBarTaskCounter(
        name=task_name,
        total=total,
        prompt_toolkit_counter=pt_counter,
    )
    state.tasks[task_name] = task_counter

    # hide_eta is set as a dynamic attribute on the stored counter,
    # hence the type: ignore.
    task_counter.prompt_toolkit_counter.hide_eta = hide_eta  # type: ignore
42
43
def update_progress(
    task_name: str, count=1, completed=False, canceled=False, new_total=None
):
    """Apply a single progress action to a registered task.

    Exactly one action is taken, checked in this priority order:
    ``completed``, ``canceled``, ``new_total``, and otherwise an increment by
    ``count``. After the action, if a progress bar instance exists and every
    task reports complete, the instance's ``__exit__`` is called (when the
    instance defines one).

    Args:
        task_name: Key the task was stored under in ``progress_state.tasks``
            (see ``start_progress``). Unknown names are silently ignored.
        count: Amount passed to the task's ``update`` when no other action is
            requested.
        completed: If true, call ``mark_completed()`` on the task.
        canceled: If true (and ``completed`` is false), call
            ``mark_canceled()`` on the task.
        new_total: If not ``None``, pass it to ``set_new_total()``. A value
            of ``0`` is honored as a new total.
    """
    progress_state = TASKS_CONTEXTVAR.get()
    # The caller may not actually get canceled and will continue trying to
    # update after an interrupt, so unknown tasks are ignored.
    if task_name not in progress_state.tasks:
        return
    task = progress_state.tasks[task_name]

    # Take exactly one action, in priority order.
    if completed:
        task.mark_completed()
    elif canceled:
        task.mark_canceled()
    elif new_total is not None:
        # Compare against None rather than truthiness so a new total of 0 is
        # not mistaken for "no new total" and turned into a count increment.
        task.set_new_total(new_total)
    else:
        task.update(count)

    # Once every task is complete, exit the progress bar instance's context.
    if (
        progress_state.instance is not None
        and progress_state.all_tasks_complete
        and hasattr(progress_state.instance, '__exit__')
    ):
        # The context-manager protocol passes (exc_type, exc_value,
        # traceback). None for all three means "exited without an exception".
        # This also works for implementations declared as __exit__(self, *a).
        progress_state.instance.__exit__(None, None, None)
70