• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Device classes to interact with targets via RPC."""
15
16import datetime
17import logging
18from pathlib import Path
19from types import ModuleType
20from typing import Any, Callable, List, Union, Optional
21
22from pw_hdlc.rpc import HdlcRpcClient, default_channels
23from pw_log_tokenized import FormatStringWithMetadata
24from pw_log.proto import log_pb2
25from pw_metric import metric_parser
26from pw_rpc import callback_client, console_tools
27from pw_status import Status
28from pw_thread.thread_analyzer import ThreadSnapshotAnalyzer
29from pw_thread_protos import thread_pb2
30from pw_tokenizer import detokenize
31from pw_tokenizer.proto import decode_optionally_tokenized
32from pw_unit_test.rpc import run_tests as pw_unit_test_run_tests
33
# Internal logger for troubleshooting this tool (the console) itself.
_LOG = logging.getLogger('tools')
# Default logger that Device instances use to emit the device's own logs.
DEFAULT_DEVICE_LOGGER = logging.getLogger('rpc_device')
37
38
class Device:
    """Represents an RPC client for a device running a Pigweed target.

    The target must have RPC support and RPC logging (the latter is only used
    when ``use_rpc_logging`` is set).
    Note: use this class as a base for specialized device representations.
    """

    def __init__(
        self,
        channel_id: int,
        read,
        write,
        proto_library: List[Union[ModuleType, Path]],
        detokenizer: Optional[detokenize.Detokenizer],
        timestamp_decoder: Optional[Callable[[int], str]],
        rpc_timeout_s: float = 5,
        use_rpc_logging: bool = True,
    ):
        """Creates the HDLC RPC client and optionally starts the log stream.

        Args:
          channel_id: Stored on the instance as ``channel_id``.
          read: Passed to ``HdlcRpcClient`` as its read argument.
          write: Passed to ``default_channels`` to build the client channels.
          proto_library: Proto modules or paths handed to ``HdlcRpcClient``.
          detokenizer: Used to detokenize device output and log messages; when
            falsy, messages are decoded as plain UTF-8 text instead.
          timestamp_decoder: Optional override used by ``decode_timestamp``.
          rpc_timeout_s: Default unary RPC timeout, also used for metrics.
          use_rpc_logging: When true, opens the log stream RPC immediately.
        """
        self.channel_id = channel_id
        self.protos = proto_library
        self.detokenizer = detokenizer
        self.rpc_timeout_s = rpc_timeout_s

        self.logger = DEFAULT_DEVICE_LOGGER
        self.logger.setLevel(logging.DEBUG)  # Allow all device logs through.
        self.timestamp_decoder = timestamp_decoder
        self._expected_log_sequence_id = 0

        callback_client_impl = callback_client.Impl(
            default_unary_timeout_s=self.rpc_timeout_s,
            default_stream_timeout_s=None,
        )

        def detokenize_and_log_output(data: bytes, _detokenizer=None):
            """Output callback given to HdlcRpcClient: logs ``data`` by line."""
            if self.detokenizer:
                log_messages = decode_optionally_tokenized(
                    self.detokenizer, data
                )
            else:
                # surrogateescape keeps undecodable bytes instead of raising.
                log_messages = data.decode(
                    encoding='utf-8', errors='surrogateescape'
                )

            for line in log_messages.splitlines():
                self.logger.info(line)

        self.client = HdlcRpcClient(
            read,
            self.protos,
            default_channels(write),
            detokenize_and_log_output,
            client_impl=callback_client_impl,
        )

        if use_rpc_logging:
            # Start listening to logs as soon as possible.
            self.listen_to_log_stream()

    def info(self) -> console_tools.ClientInfo:
        """Returns a ``ClientInfo`` named 'device' for this client."""
        return console_tools.ClientInfo('device', self.rpcs, self.client.client)

    @property
    def rpcs(self) -> Any:
        """Returns the RPC accessor of the client's first channel."""
        return next(iter(self.client.client.channels())).rpcs

    def run_tests(self, timeout_s: Optional[float] = 5) -> bool:
        """Runs the unit tests on this device."""
        return pw_unit_test_run_tests(self.rpcs, timeout_s=timeout_s)

    def listen_to_log_stream(self) -> None:
        """Opens a log RPC for the device's unrequested log stream.

        The RPCs remain open until the server cancels or closes them, either
        with a response or error packet.
        """
        self.rpcs.pw.log.Logs.Listen.open(
            on_next=lambda _, log_entries_proto: self._log_entries_proto_parser(
                log_entries_proto
            ),
            on_completed=lambda _, status: _LOG.info(
                'Log stream completed with status: %s', status
            ),
            on_error=lambda _, error: self._handle_log_stream_error(error),
        )

    def _handle_log_stream_error(self, error: Status) -> None:
        """Resets the log stream RPC on error to avoid losing logs."""
        _LOG.error('Log stream error: %s', error)

        # Only re-request logs if the RPC was not cancelled by the client.
        if error != Status.CANCELLED:
            self.listen_to_log_stream()

    def _handle_log_drop_count(self, drop_count: int, reason: str) -> None:
        """Emits a warning-level device log reporting dropped logs."""
        log_text = 'log' if drop_count == 1 else 'logs'
        message = f'Dropped {drop_count} {log_text} due to {reason}'
        self._emit_device_log(logging.WARNING, '', '', message)

    def _check_for_dropped_logs(
        self, log_entries_proto: log_pb2.LogEntries
    ) -> None:
        """Compares the batch's first sequence ID against the expected one.

        A gap is reported as logs lost at transport; an ID lower than expected
        is logged as an error. The expected ID is then advanced past this
        batch.
        """
        # Count log messages received that don't use the dropped field.
        messages_received = sum(
            1 for log_proto in log_entries_proto.entries
            if not log_proto.dropped
        )
        first_sequence_id = log_entries_proto.first_entry_sequence_id
        dropped_log_count = first_sequence_id - self._expected_log_sequence_id
        self._expected_log_sequence_id = first_sequence_id + messages_received

        if dropped_log_count > 0:
            self._handle_log_drop_count(dropped_log_count, 'loss at transport')
        elif dropped_log_count < 0:
            _LOG.error('Log sequence ID is smaller than expected')

    def _log_entries_proto_parser(
        self, log_entries_proto: log_pb2.LogEntries
    ) -> None:
        """Emits a device log for each entry in a received log batch."""
        self._check_for_dropped_logs(log_entries_proto)
        for log_proto in log_entries_proto.entries:
            # Handle dropped count first: such entries only report a drop, so
            # there is no log line to detokenize or timestamp.
            if log_proto.dropped:
                drop_reason = (
                    log_proto.message.decode('utf-8').lower()
                    if log_proto.message
                    else 'enqueue failure on device'
                )
                self._handle_log_drop_count(log_proto.dropped, drop_reason)
                continue

            decoded_timestamp = self.decode_timestamp(log_proto.timestamp)
            # Parse level and convert to logging module level number.
            level = (log_proto.line_level & 0x7) * 10
            if self.detokenizer:
                message = str(
                    decode_optionally_tokenized(
                        self.detokenizer, log_proto.message
                    )
                )
            else:
                message = log_proto.message.decode('utf-8')
            log = FormatStringWithMetadata(message)

            self._emit_device_log(
                level,
                decoded_timestamp,
                log.module,
                log.message,
                **dict(log.fields),
            )

    def _emit_device_log(
        self,
        level: int,
        timestamp: str,
        module_name: str,
        message: str,
        **metadata_fields,
    ) -> None:
        """Logs one device message on ``self.logger`` with its metadata.

        The metadata (plus ``timestamp``, ``msg`` and ``module``, which
        override same-named metadata keys) is attached to the log record as
        ``extra_metadata_fields``.
        """
        # Fields used for console table view
        fields = metadata_fields
        fields.update(timestamp=timestamp, msg=message, module=module_name)

        # Format used for file or stdout logging.
        self.logger.log(
            level,
            '%s %s%s',
            timestamp,
            # An empty module name must not leave a stray separator space.
            f'{module_name} '.lstrip(),
            message,
            extra={'extra_metadata_fields': fields},
        )

    def decode_timestamp(self, timestamp: int) -> str:
        """Decodes timestamp to a human-readable value.

        Defaults to interpreting the input timestamp as nanoseconds since boot.
        Devices can override this to match their timestamp units.

        Returns:
          The ``timestamp_decoder`` result when one was provided; otherwise the
          elapsed time formatted by ``datetime.timedelta`` with millisecond
          precision, e.g. ``0:00:01.500``.
        """
        if self.timestamp_decoder:
            return self.timestamp_decoder(timestamp)

        elapsed = str(datetime.timedelta(seconds=timestamp / 1e9))
        # str(timedelta) omits the fractional part when microseconds are zero;
        # pad it so the slice below trims microseconds to milliseconds rather
        # than cutting off the seconds field.
        if '.' not in elapsed:
            elapsed += '.000000'
        return elapsed[:-3]

    def get_and_log_metrics(self) -> dict:
        """Retrieves the parsed metrics and logs them to the console."""
        metrics = metric_parser.parse_metrics(
            self.rpcs, self.detokenizer, self.rpc_timeout_s
        )

        def print_metrics(metrics, path):
            """Traverses dictionaries, until a non-dict value is reached."""
            for path_name, metric in metrics.items():
                if isinstance(metric, dict):
                    print_metrics(metric, path + '/' + path_name)
                else:
                    _LOG.info('%s/%s: %s', path, path_name, str(metric))

        print_metrics(metrics, '')
        return metrics

    def snapshot_peak_stack_usage(
        self, thread_name: Optional[str] = None
    ) -> None:
        """Requests peak stack usage from the device and logs the analysis.

        Args:
          thread_name: Passed as ``name`` to the ``GetPeakStackUsage`` RPC.
        """
        _, rsp = self.rpcs.pw.thread.ThreadSnapshotService.GetPeakStackUsage(
            name=thread_name
        )

        # Merge the threads of every response block into a single proto so
        # they are analyzed together.
        thread_info = thread_pb2.SnapshotThreadInfo()
        for thread_info_block in rsp:
            for thread in thread_info_block.threads:
                thread_info.threads.append(thread)
        for line in str(ThreadSnapshotAnalyzer(thread_info)).splitlines():
            _LOG.info('%s', line)
251