• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Device classes to interact with targets via RPC."""
15
16import logging
17from pathlib import Path
18from types import ModuleType
19from typing import Any, Callable
20
21from pw_thread_protos import thread_pb2
22from pw_hdlc import rpc
23from pw_log import log_decoder
24from pw_log_rpc import rpc_log_stream
25from pw_metric import metric_parser
26import pw_rpc
27from pw_rpc import callback_client, console_tools
28from pw_thread import thread_analyzer
29from pw_tokenizer import detokenize
30from pw_tokenizer.proto import decode_optionally_tokenized
31from pw_unit_test.rpc import run_tests as pw_unit_test_run_tests, TestRecord
32
# Logger used to troubleshoot this tool itself (the console), as opposed to
# the output that comes from the device.
_LOG = logging.getLogger(name='tools')

# Logger that receives device logs when Device is not handed one explicitly.
DEFAULT_DEVICE_LOGGER = logging.getLogger(name='rpc_device')
36
37
38# pylint: disable=too-many-arguments
class Device:
    """Represents an RPC client for a device running a Pigweed target.

    The target must support RPC. When ``use_rpc_logging`` is enabled (the
    default) it is also expected to serve the RPC log stream, which is
    subscribed to as soon as the object is constructed.

    Note: use this class as a base for specialized device representations.
    """

    def __init__(
        self,
        channel_id: int,
        reader: rpc.CancellableReader,
        write: Callable[[bytes], Any],
        proto_library: list[ModuleType | Path],
        detokenizer: detokenize.Detokenizer | None = None,
        timestamp_decoder: Callable[[int], str] | None = None,
        rpc_timeout_s: float = 5,
        use_rpc_logging: bool = True,
        use_hdlc_encoding: bool = True,
        logger: logging.Logger | logging.LoggerAdapter = DEFAULT_DEVICE_LOGGER,
    ):
        """Creates the RPC client and optionally starts streaming device logs.

        Args:
          channel_id: ID of the RPC channel used to talk to the device.
          reader: Reader that supplies the data received from the device.
          write: Function used to send data to the device. When
            ``use_hdlc_encoding`` is True it is wrapped with pw_hdlc's
            ``channel_output``; otherwise it is used as the channel output
            directly.
          proto_library: Proto modules or .proto paths that define the RPC
            services available on the device.
          detokenizer: Optional detokenizer for tokenized logs and output.
          timestamp_decoder: Optional function that turns a raw log timestamp
            into a string. Defaults to
            ``log_decoder.timestamp_parser_ns_since_boot``.
          rpc_timeout_s: Default timeout for unary RPCs; also passed to the
            metric parser.
          use_rpc_logging: If True, immediately start listening to the
            device's RPC log stream.
          use_hdlc_encoding: If True, use an HDLC-framed RPC client; otherwise
            use an unencoded single-channel RPC client.
          logger: Logger that receives device logs and output. Its level is
            set to DEBUG so every device log is let through.
        """
        self.channel_id = channel_id
        self.protos = proto_library
        self.detokenizer = detokenizer
        self.rpc_timeout_s = rpc_timeout_s

        self.logger = logger
        self.logger.setLevel(logging.DEBUG)  # Allow all device logs through.

        # Unary RPCs default to rpc_timeout_s; streaming RPCs (such as the log
        # stream) get no default timeout.
        callback_client_impl = callback_client.Impl(
            default_unary_timeout_s=self.rpc_timeout_s,
            default_stream_timeout_s=None,
        )

        def detokenize_and_log_output(data: bytes, _detokenizer=None):
            """Logs data handed to the HDLC client's output callback.

            ``_detokenizer`` is accepted but unused; ``self.detokenizer`` is
            used instead.
            """
            # Only decode once: detokenize when possible, otherwise treat the
            # data as UTF-8 text while preserving undecodable bytes.
            if self.detokenizer:
                log_messages = decode_optionally_tokenized(
                    self.detokenizer, data
                )
            else:
                log_messages = data.decode(
                    encoding='utf-8', errors='surrogateescape'
                )

            for line in log_messages.splitlines():
                self.logger.info(line)

        self.client: rpc.RpcClient
        if use_hdlc_encoding:
            channels = [
                pw_rpc.Channel(self.channel_id, rpc.channel_output(write))
            ]
            self.client = rpc.HdlcRpcClient(
                reader,
                self.protos,
                channels,
                detokenize_and_log_output,
                client_impl=callback_client_impl,
            )
        else:
            channel = pw_rpc.Channel(self.channel_id, write)
            self.client = rpc.NoEncodingSingleChannelRpcClient(
                reader,
                self.protos,
                channel,
                client_impl=callback_client_impl,
            )

        if use_rpc_logging:
            # Create the log decoder used by the LogStreamHandler; decoded
            # entries are forwarded to self.logger.

            def decoded_log_handler(log: log_decoder.Log) -> None:
                log_decoder.log_decoded_log(log, self.logger)

            self._log_decoder = log_decoder.LogStreamDecoder(
                decoded_log_handler=decoded_log_handler,
                detokenizer=self.detokenizer,
                source_name='RpcDevice',
                timestamp_parser=(
                    timestamp_decoder
                    or log_decoder.timestamp_parser_ns_since_boot
                ),
            )

            # Start listening to logs as soon as possible.
            self.log_stream_handler = rpc_log_stream.LogStreamHandler(
                self.rpcs, self._log_decoder
            )
            self.log_stream_handler.listen_to_logs()

    def __enter__(self):
        """Returns this device so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info) -> None:
        """Closes the device when leaving a ``with`` statement."""
        self.close()

    def close(self) -> None:
        """Closes the underlying RPC client."""
        self.client.close()

    def info(self) -> console_tools.ClientInfo:
        """Returns console client info describing this device's RPCs."""
        return console_tools.ClientInfo('device', self.rpcs, self.client.client)

    @property
    def rpcs(self) -> Any:
        """Returns an object for accessing services on the client's channel.

        The first channel reported by the client is used; this class creates
        exactly one channel in either encoding mode.
        """
        return next(iter(self.client.client.channels())).rpcs

    def run_tests(self, timeout_s: float | None = 5) -> TestRecord:
        """Runs the unit tests on this device.

        Args:
          timeout_s: Timeout forwarded to pw_unit_test's ``run_tests``.

        Returns:
          The TestRecord produced by the test run.
        """
        return pw_unit_test_run_tests(self.rpcs, timeout_s=timeout_s)

    def get_and_log_metrics(self) -> dict:
        """Retrieves the parsed metrics and logs each one.

        Every leaf metric is logged at INFO level through the module's
        internal ``_LOG`` logger as ``<path>/<name>: <value>``, where the path
        is formed from the enclosing dictionary keys.

        Returns:
          The parsed metrics as a (possibly nested) dictionary.
        """
        metrics = metric_parser.parse_metrics(
            self.rpcs, self.detokenizer, self.rpc_timeout_s
        )

        def log_metrics(subtree: dict, path: str) -> None:
            """Recurses into nested dictionaries, logging each leaf value."""
            for name, value in subtree.items():
                if isinstance(value, dict):
                    log_metrics(value, f'{path}/{name}')
                else:
                    _LOG.info('%s/%s: %s', path, name, str(value))

        log_metrics(metrics, '')
        return metrics

    def snapshot_peak_stack_usage(
        self, thread_name: str | None = None
    ) -> None:
        """Requests peak stack usage from the device and logs an analysis.

        Args:
          thread_name: Passed as ``name`` to ``GetPeakStackUsage``; presumably
            restricts the query to one thread when set (TODO: confirm against
            the ThreadSnapshotService definition).
        """
        snapshot_service = self.rpcs.pw.thread.proto.ThreadSnapshotService
        # NOTE(review): the RPC status is discarded; only the received
        # responses are analyzed.
        _, responses = snapshot_service.GetPeakStackUsage(name=thread_name)

        # Merge the threads from every response into a single message.
        thread_info = thread_pb2.SnapshotThreadInfo()
        for response in responses:
            thread_info.threads.extend(response.threads)

        analysis = thread_analyzer.ThreadSnapshotAnalyzer(thread_info)
        for line in str(analysis).splitlines():
            _LOG.info('%s', line)
179