• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Pigweed Console progress bar task state."""
15
16from contextvars import ContextVar
17import copy
18from dataclasses import dataclass, field
19import signal
20from typing import Dict, Optional, Union
21
22from prompt_toolkit.application import get_app_or_none
23from prompt_toolkit.shortcuts import ProgressBar
24from prompt_toolkit.shortcuts.progress_bar import formatters
25
26from pw_console.progress_bar.progress_bar_impl import (
27    IterationsPerSecondIfNotHidden,
28    ProgressBarImpl,
29    TextIfNotHidden,
30    TimeLeftIfNotHidden,
31)
32from pw_console.progress_bar.progress_bar_task_counter import (
33    ProgressBarTaskCounter)
34from pw_console.style import generate_styles
35
# Formatter pipeline shared by every progress bar. prompt_toolkit renders the
# formatters left to right, so the list below reads as one line of output,
# roughly: "<label>: |Pigweed!   | 3/10 [30%] in 0:01 (2.0/s, eta: 0:04)".
CUSTOM_FORMATTERS = [
    # "<task label>: "
    formatters.Label(suffix=': '),
    # Rainbow-colored bar. The start text, fill symbol ('e') and head symbol
    # ('d!') are chosen so the filled bar reads like "Pigweed!" (the number
    # of 'e's grows with progress).
    formatters.Rainbow(
        formatters.Bar(
            start='|Pigw',
            end='|',
            sym_a='e',
            sym_b='d!',
            sym_c=' ',
        )),
    # " <done>/<total> [<percent>] in <elapsed>"
    formatters.Text(' '),
    formatters.Progress(),
    formatters.Text(' ['),
    formatters.Percentage(),
    formatters.Text('] in '),
    formatters.TimeElapsed(),
    # " (<rate>/s, eta: <time left>)". The *IfNotHidden formatters come from
    # progress_bar_impl; presumably they render nothing when the rate/ETA
    # is hidden -- confirm there.
    TextIfNotHidden(' ('),
    IterationsPerSecondIfNotHidden(),
    TextIfNotHidden('/s, eta: '),
    TimeLeftIfNotHidden(),
    TextIfNotHidden(')'),
]
56
57
def prompt_toolkit_app_running() -> bool:
    """Report whether a prompt_toolkit application is active.

    Returns:
      True when get_app_or_none() yields a (truthy) application object,
      otherwise False.
    """
    return bool(get_app_or_none())
63
64
@dataclass
class ProgressBarState:
    """Pigweed Console wide state for all repl progress bars.

    An instance of this class is intended to be a global variable.

    Attributes:
      tasks: Tracked task counters keyed by task name.
      instance: The lazily created progress bar. This is a ProgressBarImpl
        when a prompt_toolkit application is already running (e.g. inside
        pw_console), otherwise a standalone prompt_toolkit ProgressBar.
    """
    tasks: Dict[str, ProgressBarTaskCounter] = field(default_factory=dict)
    instance: Optional[Union[ProgressBar, ProgressBarImpl]] = None

    def _install_sigint_handler(self) -> None:
        """Add ctrl-c handling if not running inside pw_console.

        The handler exits the progress bar (when it supports __exit__) and
        then raises KeyboardInterrupt.

        Installation is best effort. signal.signal() may only be called from
        the main thread of the main interpreter and raises ValueError
        elsewhere. In that case the existing SIGINT handling is left
        untouched rather than failing progress bar startup.
        """

        def handle_sigint(_signum, _frame):
            # Shut down the ProgressBar prompt_toolkit application, which
            # startup_progress_bar_impl started in a separate thread.
            prog_bar = self.instance
            if prog_bar is not None and hasattr(prog_bar, '__exit__'):
                prog_bar.__exit__()
            raise KeyboardInterrupt

        try:
            signal.signal(signal.SIGINT, handle_sigint)
        except ValueError:
            # Not on the main thread: signal handlers cannot be installed
            # here. Deliberately skip; the progress bar still works.
            pass

    def startup_progress_bar_impl(self):
        """Return the shared progress bar, creating it on first use.

        If a prompt_toolkit application is running, the new progress bar is
        a ProgressBarImpl that reuses that application's style. Otherwise a
        SIGINT handler is installed and a standalone ProgressBar is created
        and entered.

        Returns:
          The progress bar stored in self.instance.
        """
        if self.instance is None:
            # Look the application up once, so the style is read from the
            # same object that was detected. A second lookup could observe
            # None if the application exits in between.
            app = get_app_or_none()
            if app is not None:
                prog_bar = ProgressBarImpl(style=app.style,
                                           formatters=CUSTOM_FORMATTERS)
            else:
                self._install_sigint_handler()
                prog_bar = ProgressBar(style=generate_styles(),
                                       formatters=CUSTOM_FORMATTERS)
                # Start the ProgressBar prompt_toolkit application in a
                # separate thread.
                prog_bar.__enter__()
            self.instance = prog_bar
        return self.instance

    def cleanup_finished_tasks(self) -> None:
        """Forget completed or canceled tasks and drop their counters.

        Each finished task is removed from self.tasks. If the progress bar
        exists and holds the task's prompt_toolkit counter, that counter is
        removed from the progress bar's counters list.
        """
        # Iterate over a shallow copy because entries are popped from
        # self.tasks inside the loop.
        for task_name, task in copy.copy(self.tasks).items():
            if not (task.completed or task.canceled):
                continue
            self.tasks.pop(task_name, None)
            ptc = task.prompt_toolkit_counter
            counters = (self.instance.counters
                        if self.instance is not None else None)
            if counters and ptc in counters:
                counters.remove(ptc)

    @property
    def all_tasks_complete(self) -> bool:
        """True if every tracked task is completed or canceled.

        Also True when no tasks are tracked. As a side effect, finished tasks
        are removed via cleanup_finished_tasks() after the check.
        """
        tasks_complete = all(task.completed or task.canceled
                             for task in self.tasks.values())
        self.cleanup_finished_tasks()
        return tasks_complete

    def cancel_all_tasks(self):
        """Forget every tracked task and clear the progress bar's counters.

        The task objects are discarded as-is; their completed/canceled flags
        are not modified.
        """
        self.tasks = {}
        if self.instance is not None:
            self.instance.counters = []

    def get_container(self):
        """Return the progress bar's prompt_toolkit container, if any.

        Returns:
          The result of the progress bar's __pt_container__(), or None when
          no progress bar exists or it does not provide __pt_container__.
        """
        prog_bar = self.instance
        if prog_bar is not None and hasattr(prog_bar, '__pt_container__'):
            return prog_bar.__pt_container__()
        return None
129
130
# Context variable holding the progress bar state. Its default is a single
# ProgressBarState built at import time. ContextVar.get() returns that same
# default object in every context that has not set its own value, so the
# state is effectively shared process-wide.
TASKS_CONTEXTVAR = ContextVar(
    'pw_console_progress_bar_tasks',
    default=ProgressBarState(),
)
133