• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Pigweed Console progress bar task state."""
15
16from contextvars import ContextVar
17import copy
18from dataclasses import dataclass, field
19import signal
20from typing import Dict, Optional, Union
21
22from prompt_toolkit.application import get_app_or_none
23from prompt_toolkit.shortcuts import ProgressBar
24from prompt_toolkit.shortcuts.progress_bar import formatters
25
26from pw_console.progress_bar.progress_bar_impl import (
27    IterationsPerSecondIfNotHidden,
28    ProgressBarImpl,
29    TextIfNotHidden,
30    TimeLeftIfNotHidden,
31)
32from pw_console.progress_bar.progress_bar_task_counter import (
33    ProgressBarTaskCounter,
34)
35from pw_console.style import generate_styles
36
# Column layout shared by every progress bar, left to right. The
# *IfNotHidden formatters come from progress_bar_impl and presumably render
# nothing for a task that hides its ETA (TODO confirm in progress_bar_impl).
CUSTOM_FORMATTERS = [
    # Task label followed by ': '.
    formatters.Label(suffix=': '),
    # Rainbow-colored bar: the filled part repeats 'e' and is tipped with
    # 'd!', so it reads roughly like '|Pigweeed!   |'.
    formatters.Rainbow(
        formatters.Bar(
            start='|Pigw',
            end='|',
            sym_a='e',
            sym_b='d!',
            sym_c=' ',
        )
    ),
    formatters.Text(' '),
    # Progress count, then ' [<percent>] in <elapsed time>'.
    formatters.Progress(),
    formatters.Text(' ['),
    formatters.Percentage(),
    formatters.Text('] in '),
    formatters.TimeElapsed(),
    # Optional ' (<rate>/s, eta: <time left>)' suffix.
    TextIfNotHidden(' ('),
    IterationsPerSecondIfNotHidden(),
    TextIfNotHidden('/s, eta: '),
    TimeLeftIfNotHidden(),
    TextIfNotHidden(')'),
]
54
55
def prompt_toolkit_app_running() -> bool:
    """Report whether a prompt_toolkit application is currently active.

    Returns:
        True when ``get_app_or_none()`` yields a truthy application object,
        False otherwise.
    """
    return bool(get_app_or_none())
61
62
@dataclass
class ProgressBarState:
    """Pigweed Console wide state for all repl progress bars.

    An instance of this class is intended to be a global variable; the
    module-level ``TASKS_CONTEXTVAR`` uses one shared instance as its
    default value.
    """

    # Tracked tasks keyed by task name.
    tasks: Dict[str, ProgressBarTaskCounter] = field(default_factory=dict)
    # The bar rendering the tasks: a ProgressBarImpl when a prompt_toolkit
    # app is already running, otherwise a standalone ProgressBar. None until
    # startup_progress_bar_impl() creates one.
    instance: Optional[Union[ProgressBar, ProgressBarImpl]] = None

    def _install_sigint_handler(self) -> None:
        """Add ctrl-c handling if not running inside pw_console.

        On SIGINT the current progress bar is shut down via ``__exit__``, the
        bar and its tasks are forgotten (so a later
        startup_progress_bar_impl() builds a fresh bar instead of returning
        the exited one), and KeyboardInterrupt is raised.

        signal.signal() may only be called from the main thread; when this
        runs on any other thread the handler is not installed and the
        existing SIGINT behavior is left in place.
        """

        def handle_sigint(_signum, _frame):
            # Shut down the ProgressBar prompt_toolkit application
            prog_bar = self.instance
            if prog_bar is not None and hasattr(prog_bar, '__exit__'):
                prog_bar.__exit__()  # pylint: disable=unnecessary-dunder-call
            # The exited bar can no longer display counters, and the tasks
            # tracked for it went away with it.
            self.instance = None
            self.tasks = {}
            raise KeyboardInterrupt

        try:
            signal.signal(signal.SIGINT, handle_sigint)
        except ValueError:
            # Raised when not on the main thread of the main interpreter.
            # The handler is a convenience, so skip it instead of crashing.
            pass

    def startup_progress_bar_impl(self):
        """Return the active progress bar, creating it on first use.

        Inside a running prompt_toolkit application a ProgressBarImpl using
        that app's style is created. Otherwise a ctrl-c handler is installed
        and a standalone ProgressBar is created and entered.

        Returns:
            The progress bar stored in ``self.instance``.
        """
        if self.instance is not None:
            return self.instance

        # Fetch the app once so it cannot vanish between the check and the
        # ``.style`` access.
        running_app = get_app_or_none()
        if running_app:
            prog_bar = ProgressBarImpl(
                style=running_app.style, formatters=CUSTOM_FORMATTERS
            )
        else:
            self._install_sigint_handler()
            prog_bar = ProgressBar(
                style=generate_styles(), formatters=CUSTOM_FORMATTERS
            )
            # Start the ProgressBar prompt_toolkit application in a separate
            # thread.
            prog_bar.__enter__()  # pylint: disable=unnecessary-dunder-call
        self.instance = prog_bar
        return self.instance

    def cleanup_finished_tasks(self) -> None:
        """Drop completed or canceled tasks.

        Each dropped task's ``prompt_toolkit_counter`` is also removed from
        the active bar's ``counters`` list when present there.
        """
        # Iterate a snapshot: finished entries are popped from self.tasks.
        for task_name, task in copy.copy(self.tasks).items():
            if not (task.completed or task.canceled):
                continue
            self.tasks.pop(task_name, None)
            ptc = task.prompt_toolkit_counter
            prog_bar = self.instance
            if (
                prog_bar is not None
                and prog_bar.counters
                and ptc in prog_bar.counters
            ):
                prog_bar.counters.remove(ptc)

    @property
    def all_tasks_complete(self) -> bool:
        """True when every tracked task is completed or canceled.

        Also True when no tasks are tracked. As a side effect, finished
        tasks are pruned with cleanup_finished_tasks() after the check.
        """
        # Must be evaluated before cleanup, which removes the finished tasks.
        tasks_complete = all(
            task.completed or task.canceled for task in self.tasks.values()
        )
        self.cleanup_finished_tasks()
        return tasks_complete

    def cancel_all_tasks(self):
        """Forget every task and clear the active bar's counters."""
        self.tasks = {}
        if self.instance is not None:
            self.instance.counters = []

    def get_container(self):
        """Return the progress bar's prompt_toolkit container, if any.

        Returns:
            ``self.instance.__pt_container__()`` when there is a bar that
            implements it, otherwise None.
        """
        prog_bar = self.instance
        if prog_bar is not None and hasattr(prog_bar, '__pt_container__'):
            return prog_bar.__pt_container__()
        return None
134
135
# ContextVar evaluates its default once, so every context that never sets
# its own value shares this single ProgressBarState instance.
TASKS_CONTEXTVAR = ContextVar(
    'pw_console_progress_bar_tasks',
    default=ProgressBarState(),
)
139