# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Fetch active Project Builder Context."""
15
16import asyncio
17import concurrent.futures
18from contextvars import ContextVar
19from datetime import datetime
20from dataclasses import dataclass, field
21from enum import Enum
22import logging
23import os
24import subprocess
25from typing import Callable, Dict, List, Optional, NoReturn, TYPE_CHECKING
26
27from prompt_toolkit.formatted_text import (
28    AnyFormattedText,
29    StyleAndTextTuples,
30)
31from prompt_toolkit.key_binding import KeyBindings
32from prompt_toolkit.layout.dimension import AnyDimension, D
33from prompt_toolkit.layout import (
34    AnyContainer,
35    DynamicContainer,
36    FormattedTextControl,
37    Window,
38)
39from prompt_toolkit.shortcuts import ProgressBar, ProgressBarCounter
40from prompt_toolkit.shortcuts.progress_bar.formatters import (
41    Formatter,
42    TimeElapsed,
43)
44from prompt_toolkit.shortcuts.progress_bar import formatters
45
46from pw_build.build_recipe import BuildRecipe
47
48if TYPE_CHECKING:
49    from pw_build.project_builder import ProjectBuilder
50
# Module-level logger. Note that it is registered under the 'pw_build.watch'
# name rather than under this module's own name.
_LOG = logging.getLogger('pw_build.watch')
52
53
def _wait_for_terminate_then_kill(
    proc: subprocess.Popen, timeout: int = 5
) -> int:
    """Wait for a process to end, then kill it if the timeout expires.

    Args:
        proc: The subprocess to wait on.
        timeout: Seconds to wait for ``proc`` to exit on its own before it is
            killed.

    Returns:
        The exit code of ``proc`` if it ended within ``timeout``; otherwise 1,
        after the process has been killed and reaped.
    """
    try:
        return proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc_command = proc.args
        if isinstance(proc.args, list):
            # Popen accepts path-like arguments, so stringify each item
            # before joining; a bare str.join() would raise TypeError.
            proc_command = ' '.join(str(arg) for arg in proc.args)
        _LOG.debug('Killing %s', proc_command)
        proc.kill()
        # Reap the killed process so it does not linger as a zombie and so
        # proc.returncode is populated for anyone inspecting it afterwards.
        proc.wait()
    return 1
68
69
class ProjectBuilderState(Enum):
    """States used for a build context's current and desired state."""

    # Default `current_state` of a ProjectBuilderContext; also assigned to
    # both current and desired state by `set_idle()`.
    IDLE = 'IDLE'
    # Assigned to both current and desired state by `set_building()`.
    BUILDING = 'BUILDING'
    # Assigned to `desired_state` by `_signal_abort()`; `should_abort()`
    # checks for it.
    ABORT = 'ABORT'
74
75
76# pylint: disable=unused-argument
77# Prompt Toolkit progress bar formatter classes follow:
class BuildStatus(Formatter):
    """Progress bar column with a recipe's status slug and current step."""

    def __init__(self, ctx) -> None:
        # The context supplies `recipes` and `restart_flag` at format time.
        self.ctx = ctx

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter,
        width: int,
    ) -> AnyFormattedText:
        """Return the status fragments for the recipe labelled `progress`.

        The first recipe whose display name equals the counter's label is
        used. An empty fragment is returned when no recipe matches.
        """
        recipe = next(
            (
                candidate
                for candidate in self.ctx.recipes
                if candidate.display_name == progress.label
            ),
            None,
        )
        if recipe is None:
            return [('', '')]

        fragments: StyleAndTextTuples = [
            recipe.status.status_slug(restarting=self.ctx.restart_flag),
            ('', ' '),
        ]
        fragments.extend(recipe.status.current_step_formatted())
        return fragments

    def get_width(  # pylint: disable=no-self-use
        self, progress_bar: ProgressBar
    ) -> AnyDimension:
        """Return an unconstrained dimension for this column."""
        return D()
109
110
class TimeElapsedIfStarted(TimeElapsed):
    """Elapsed-time column that stays blank until the build has started."""

    def __init__(self, ctx) -> None:
        # The context supplies the `recipes` consulted at format time.
        self.ctx = ctx

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter,
        width: int,
    ) -> AnyFormattedText:
        """Return the elapsed time if a matching recipe has started.

        A recipe matches when its display name equals the counter's label.
        An empty fragment is returned when no matching recipe has started.
        """
        has_started = any(
            recipe.display_name == progress.label and recipe.status.started
            for recipe in self.ctx.recipes
        )
        if has_started:
            return super().format(progress_bar, progress, width)

        blank: StyleAndTextTuples = [('', '')]
        return blank
130
131
132# pylint: enable=unused-argument
133
134
@dataclass
class ProjectBuilderContext:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """Maintains the state of running builds and active subprocesses.

    Tracks the current and desired ``ProjectBuilderState``, the subprocess
    registered for each recipe, and the prompt_toolkit ``ProgressBar`` (once
    one has been started) used to display per-recipe status.
    """

    current_state: ProjectBuilderState = ProjectBuilderState.IDLE
    desired_state: ProjectBuilderState = ProjectBuilderState.BUILDING
    # Most recently registered subprocess for each recipe.
    procs: Dict[BuildRecipe, subprocess.Popen] = field(default_factory=dict)
    recipes: List[BuildRecipe] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Set up the non-dataclass attributes (UI objects and flags)."""
        self.project_builder: Optional['ProjectBuilder'] = None

        # Columns shown for each progress bar row, left to right.
        self.progress_bar_formatters = [
            formatters.Text(' '),
            formatters.Label(),
            formatters.Text(' '),
            BuildStatus(self),
            formatters.Text(' '),
            TimeElapsedIfStarted(self),
            formatters.Text(' '),
        ]

        self._enter_callback: Optional[Callable] = None

        key_bindings = KeyBindings()

        @key_bindings.add('enter')
        def _enter_pressed(_event):
            """Run enter press function."""
            if self._enter_callback:
                self._enter_callback()

        self.key_bindings = key_bindings

        self.progress_bar: Optional[ProgressBar] = None

        self._progress_bar_started: bool = False

        self.bottom_toolbar: AnyFormattedText = None
        self.horizontal_separator = '━'
        # Placeholder title bar: a single row of separator characters. It is
        # replaced by create_title_bar_container().
        self.title_bar_container: AnyContainer = Window(
            char=self.horizontal_separator, height=1
        )

        self.using_fullscreen: bool = False
        self.restart_flag: bool = False
        self.ctrl_c_pressed: bool = False

    def using_progress_bars(self) -> bool:
        """Return True if a progress bar exists or fullscreen is in use."""
        return bool(self.progress_bar) or self.using_fullscreen

    def interrupted(self) -> bool:
        """Return True if ctrl-c was pressed or a restart was requested."""
        return self.ctrl_c_pressed or self.restart_flag

    def set_bottom_toolbar(self, text: AnyFormattedText) -> None:
        """Store the text passed to the progress bar's bottom toolbar."""
        self.bottom_toolbar = text

    def set_enter_callback(self, callback: Callable) -> None:
        """Set the function invoked when the enter key is pressed."""
        self._enter_callback = callback

    def ctrl_c_interrupt(self) -> None:
        """Abort function for when using ProgressBars."""
        self.ctrl_c_pressed = True
        self.exit(1)

    def startup_progress(self) -> None:
        """Create and enter the ProgressBar, then install the title bar."""
        self.progress_bar = ProgressBar(
            formatters=self.progress_bar_formatters,
            key_bindings=self.key_bindings,
            title=self.get_title_bar_text,
            bottom_toolbar=self.bottom_toolbar,
            cancel_callback=self.ctrl_c_interrupt,
        )
        self.progress_bar.__enter__()  # pylint: disable=unnecessary-dunder-call

        self.create_title_bar_container()
        # Swap the first child of the progress bar layout for a dynamic
        # container so that later changes to title_bar_container are shown.
        self.progress_bar.app.layout.container.children[  # type: ignore
            0
        ] = DynamicContainer(lambda: self.title_bar_container)
        self._progress_bar_started = True

    def exit_progress(self) -> None:
        """Exit the ProgressBar context if one was created."""
        if not self.progress_bar:
            return
        self.progress_bar.__exit__()  # pylint: disable=unnecessary-dunder-call

    def clear_progress_scrollback(self) -> None:
        """Schedule a renderer clear on the progress bar's event loop."""
        if not self.progress_bar:
            return
        self.progress_bar._app_loop.call_soon_threadsafe(  # pylint: disable=protected-access
            self.progress_bar.app.renderer.clear
        )

    def redraw_progress(self) -> None:
        """Invalidate the progress bar so it is redrawn."""
        if not self.progress_bar:
            return
        if hasattr(self.progress_bar, 'app'):
            self.progress_bar.invalidate()

    def _any_recipe_failed(self) -> bool:
        """Return True if at least one recipe reports a failed status."""
        return any(cfg.status.failed() for cfg in self.recipes)

    def get_title_style(self) -> str:
        """Return the title bar style string for the current build state.

        Yellow while restarting, red if any recipe failed, yellow while
        building and green otherwise.
        """
        if self.restart_flag:
            return 'fg:ansiyellow'

        # A failure takes priority over the in-progress colour.
        if self._any_recipe_failed():
            return 'fg:ansired'

        if self.current_state == ProjectBuilderState.BUILDING:
            return 'fg:ansiyellow'

        # Assume passing
        return 'fg:ansigreen'

    def exit_code(self) -> int:
        """Returns a 0 for success, 1 for fail."""
        return 1 if self._any_recipe_failed() else 0

    def get_title_bar_text(
        self, include_separators: bool = True
    ) -> StyleAndTextTuples:
        """Return the styled title bar text summarising the build.

        Args:
            include_separators: Prefix the title with two separator
                characters and a space when True.

        Returns:
            A single (style, text) fragment whose text is one of INTERRUPT,
            FAILED (n), PASS or the current state's name.
        """
        fail_count = sum(1 for cfg in self.recipes if cfg.status.failed())
        any_done = any(cfg.status.done for cfg in self.recipes)

        if self.restart_flag:
            title = 'INTERRUPT'
        elif fail_count > 0:
            title = f'FAILED ({fail_count})'
        elif self.current_state == ProjectBuilderState.IDLE and any_done:
            title = 'PASS'
        else:
            title = self.current_state.name

        prefix = ''
        if include_separators:
            prefix += f'{self.horizontal_separator}{self.horizontal_separator} '

        return [(self.get_title_style(), f'{prefix}{title} ')]

    def create_title_bar_container(self) -> None:
        """Build the one-row title bar window showing the title text."""
        title_text = FormattedTextControl(self.get_title_bar_text)
        self.title_bar_container = Window(
            title_text,
            char=self.horizontal_separator,
            height=1,
            # Expand width to max available space
            dont_extend_width=False,
            style=self.get_title_style,
        )

    def add_progress_bars(self) -> None:
        """Replace the progress counters with one per recipe.

        Starts the progress bar first if that has not happened yet.
        """
        if not self._progress_bar_started:
            self.startup_progress()
        assert self.progress_bar
        self.clear_progress_bars()
        for cfg in self.recipes:
            self.progress_bar(label=cfg.display_name, total=len(cfg.steps))

    def clear_progress_bars(self) -> None:
        """Remove every counter from the progress bar."""
        if not self.progress_bar:
            return
        self.progress_bar.counters = []

    def _find_progress_counter(
        self, recipe: BuildRecipe
    ) -> Optional[ProgressBarCounter]:
        """Return the first counter labelled with the recipe's display name.

        Returns None when there is no progress bar or no matching counter.
        """
        if not self.progress_bar:
            return None
        for counter in self.progress_bar.counters:
            if counter.label == recipe.display_name:
                return counter
        return None

    def mark_progress_step_complete(self, recipe: BuildRecipe) -> None:
        """Advance the recipe's progress counter by one completed item."""
        counter = self._find_progress_counter(recipe)
        if counter is not None:
            counter.item_completed()

    def mark_progress_done(self, recipe: BuildRecipe) -> None:
        """Flag the recipe's progress counter as done."""
        counter = self._find_progress_counter(recipe)
        if counter is not None:
            counter.done = True

    def mark_progress_started(self, recipe: BuildRecipe) -> None:
        """Reset the recipe's progress counter start time to now."""
        counter = self._find_progress_counter(recipe)
        if counter is not None:
            counter.start_time = datetime.now()

    def register_process(
        self, recipe: BuildRecipe, proc: subprocess.Popen
    ) -> None:
        """Record `proc` as the subprocess for `recipe`."""
        self.procs[recipe] = proc

    def terminate_and_wait(
        self,
        exit_message: Optional[str] = None,
    ) -> None:
        """End subprocesses either cleanly or with a kill signal.

        Does nothing if the context is idle or an abort is already pending.
        Otherwise requests an abort, waits for every registered process in
        parallel (killing those that outlive the timeout) and then marks the
        context idle.

        Args:
            exit_message: Optional message logged at INFO level once all
                processes have been waited on.
        """
        if self.is_idle() or self.should_abort():
            return

        self._signal_abort()

        active_procs = [
            proc for proc in self.procs.values() if proc is not None
        ]

        # ThreadPoolExecutor raises ValueError for max_workers <= 0, so only
        # create the pool when there is at least one process to wait on.
        if active_procs:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(active_procs)
            ) as executor:
                futures = []
                for proc in active_procs:
                    proc_command = proc.args
                    if isinstance(proc.args, list):
                        # Arguments may be path-like; stringify before
                        # joining.
                        proc_command = ' '.join(
                            str(arg) for arg in proc.args
                        )
                    _LOG.debug('Stopping: %s', proc_command)

                    futures.append(
                        executor.submit(_wait_for_terminate_then_kill, proc)
                    )
                for future in concurrent.futures.as_completed(futures):
                    # Propagate any exception raised in a worker thread.
                    future.result()

        if exit_message:
            _LOG.info(exit_message)
        self.set_idle()

    def _signal_abort(self) -> None:
        """Request an abort by setting the desired state to ABORT."""
        self.desired_state = ProjectBuilderState.ABORT

    def build_stopping(self) -> bool:
        """Return True if the build is restarting or quitting."""
        return self.should_abort() or self.interrupted()

    def should_abort(self) -> bool:
        """Return True if the desired state is ABORT."""
        return self.desired_state == ProjectBuilderState.ABORT

    def is_building(self) -> bool:
        """Return True if the current state is BUILDING."""
        return self.current_state == ProjectBuilderState.BUILDING

    def is_idle(self) -> bool:
        """Return True if the current state is IDLE."""
        return self.current_state == ProjectBuilderState.IDLE

    def set_project_builder(self, project_builder) -> None:
        """Attach a project builder and adopt its build recipes."""
        self.project_builder = project_builder
        self.recipes = project_builder.build_recipes

    def set_idle(self) -> None:
        """Set both the current and desired state to IDLE."""
        self.current_state = ProjectBuilderState.IDLE
        self.desired_state = ProjectBuilderState.IDLE

    def set_building(self) -> None:
        """Clear the restart flag and set both states to BUILDING."""
        self.restart_flag = False
        self.current_state = ProjectBuilderState.BUILDING
        self.desired_state = ProjectBuilderState.BUILDING

    def restore_stdout_logging(self) -> None:
        """Add a stream log handler to the root logger after progress bars.

        Does nothing unless progress bars (or fullscreen) are in use. The
        project builder's stdout proxy, if any, is flushed and closed first.
        """
        if not self.using_progress_bars():
            return

        # Restore logging to STDOUT
        # NOTE(review): logging.StreamHandler() with no argument writes to
        # sys.stderr - confirm that this is the intended stream.
        stdout_handler = logging.StreamHandler()
        if self.project_builder:
            stdout_handler.setLevel(self.project_builder.default_log_level)
        else:
            stdout_handler.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        if self.project_builder and self.project_builder.stdout_proxy:
            self.project_builder.stdout_proxy.flush()
            self.project_builder.stdout_proxy.close()
        root_logger.addHandler(stdout_handler)

    def restore_logging_and_shutdown(
        self,
        log_after_shutdown: Optional[Callable[[], None]] = None,
    ) -> None:
        """Restore logging, then stop and wait for all subprocesses.

        Args:
            log_after_shutdown: Optional callable run after the abort warning
                is logged and before the subprocesses are waited on.
        """
        self.restore_stdout_logging()
        _LOG.warning('Abort signal recieved, stopping processes...')
        if log_after_shutdown:
            log_after_shutdown()
        self.terminate_and_wait()
        # Log handlers are not flushed here; exit() calls logging.shutdown()
        # after this method returns.

    def exit(
        self,
        exit_code: int = 1,
        log_after_shutdown: Optional[Callable[[], None]] = None,
    ) -> None:
        """Exit function called when the user presses ctrl-c.

        Args:
            exit_code: Process exit status to use.
            log_after_shutdown: Optional callable forwarded to
                restore_logging_and_shutdown().
        """

        # Note: The correct way to exit Python is via sys.exit() however this
        # takes a number of seconds when running pw_watch with multiple parallel
        # builds. Instead, this function calls os._exit() to shutdown
        # immediately. This is similar to `pw_watch.watch._exit`:
        # https://cs.opensource.google/pigweed/pigweed/+/main:pw_watch/py/pw_watch/watch.py?q=_exit.code

        if not self.progress_bar:
            self.restore_logging_and_shutdown(log_after_shutdown)
            logging.shutdown()
            os._exit(exit_code)  # pylint: disable=protected-access

        # Shut everything down after the progress_bar exits.
        def _really_exit(future: asyncio.Future) -> NoReturn:
            self.restore_logging_and_shutdown(log_after_shutdown)
            logging.shutdown()
            os._exit(future.result())  # pylint: disable=protected-access

        # NOTE(review): when the progress bar app has no future this method
        # returns without exiting - confirm that is intended.
        if self.progress_bar.app.future:
            self.progress_bar.app.future.add_done_callback(_really_exit)
            self.progress_bar.app.exit(result=exit_code)  # type: ignore
458
459
# Context variable holding the active ProjectBuilderContext. The default
# instance is constructed once, when this module is imported, and is the value
# returned by .get() in any context where the variable has not been set.
PROJECT_BUILDER_CONTEXTVAR = ContextVar(
    'pw_build_project_builder_state', default=ProjectBuilderContext()
)
463
464
def get_project_builder_context():
    """Return the ProjectBuilderContext held by the context variable."""
    active_context = PROJECT_BUILDER_CONTEXTVAR.get()
    return active_context
467