# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Device classes to interact with targets via RPC."""

import datetime
import logging
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, List, Optional, Union

from pw_hdlc.rpc import HdlcRpcClient, default_channels
from pw_log_tokenized import FormatStringWithMetadata
from pw_log.proto import log_pb2
from pw_metric import metric_parser
from pw_rpc import callback_client, console_tools
from pw_status import Status
from pw_thread.thread_analyzer import ThreadSnapshotAnalyzer
from pw_thread_protos import thread_pb2
from pw_tokenizer import detokenize
from pw_tokenizer.proto import decode_optionally_tokenized
from pw_unit_test.rpc import run_tests as pw_unit_test_run_tests

# Internal log for troubleshooting this tool (the console).
_LOG = logging.getLogger('tools')
DEFAULT_DEVICE_LOGGER = logging.getLogger('rpc_device')


class Device:
    """Represents an RPC client for a device running a Pigweed target.

    The target must have RPC support and RPC logging.
    Note: use this class as a base for specialized device representations.
    """

    def __init__(
        self,
        channel_id: int,
        read,
        write,
        proto_library: List[Union[ModuleType, Path]],
        detokenizer: Optional[detokenize.Detokenizer],
        timestamp_decoder: Optional[Callable[[int], str]],
        rpc_timeout_s: float = 5,
        use_rpc_logging: bool = True,
    ):
        self.channel_id = channel_id
        self.protos = proto_library
        self.detokenizer = detokenizer
        self.rpc_timeout_s = rpc_timeout_s

        self.logger = DEFAULT_DEVICE_LOGGER
        self.logger.setLevel(logging.DEBUG)  # Allow all device logs through.
        self.timestamp_decoder = timestamp_decoder
        self._expected_log_sequence_id = 0

        callback_client_impl = callback_client.Impl(
            default_unary_timeout_s=self.rpc_timeout_s,
            default_stream_timeout_s=None,
        )

        def detokenize_and_log_output(data: bytes, _detokenizer=None):
            log_messages = data.decode(
                encoding='utf-8', errors='surrogateescape'
            )

            if self.detokenizer:
                log_messages = decode_optionally_tokenized(
                    self.detokenizer, data
                )

            for line in log_messages.splitlines():
                self.logger.info(line)

        self.client = HdlcRpcClient(
            read,
            self.protos,
            default_channels(write),
            detokenize_and_log_output,
            client_impl=callback_client_impl,
        )

        if use_rpc_logging:
            # Start listening to logs as soon as possible.
            self.listen_to_log_stream()
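
    # Illustrative constructor wiring (a sketch, not part of this module; the
    # pyserial transport, port name, and read chunk size are assumptions):
    #
    #   import serial
    #
    #   ser = serial.Serial('/dev/ttyACM0', 115200, timeout=0.01)
    #   device = Device(
    #       channel_id=1,
    #       read=lambda: ser.read(4096),
    #       write=ser.write,
    #       proto_library=[log_pb2],
    #       detokenizer=None,
    #       timestamp_decoder=None,
    #   )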

    def info(self) -> console_tools.ClientInfo:
        return console_tools.ClientInfo(
            'device', self.rpcs, self.client.client
        )

    @property
    def rpcs(self) -> Any:
        """Returns an object for accessing services on the specified channel."""
        return next(iter(self.client.client.channels())).rpcs

    def run_tests(self, timeout_s: Optional[float] = 5) -> bool:
        """Runs the unit tests on this device."""
        return pw_unit_test_run_tests(self.rpcs, timeout_s=timeout_s)

    def listen_to_log_stream(self):
        """Opens a log RPC for the device's unrequested log stream.

        The RPCs remain open until the server cancels or closes them, either
        with a response or error packet.
        """
        self.rpcs.pw.log.Logs.Listen.open(
            on_next=lambda _, log_entries_proto: self._log_entries_proto_parser(
                log_entries_proto
            ),
            on_completed=lambda _, status: _LOG.info(
                'Log stream completed with status: %s', status
            ),
            on_error=lambda _, error: self._handle_log_stream_error(error),
        )

    def _handle_log_stream_error(self, error: Status):
        """Resets the log stream RPC on error to avoid losing logs."""
        _LOG.error('Log stream error: %s', error)

        # Only re-request logs if the RPC was not cancelled by the client.
        if error != Status.CANCELLED:
            self.listen_to_log_stream()

    def _handle_log_drop_count(self, drop_count: int, reason: str):
        log_text = 'log' if drop_count == 1 else 'logs'
        message = f'Dropped {drop_count} {log_text} due to {reason}'
        self._emit_device_log(logging.WARNING, '', '', message)

    def _check_for_dropped_logs(self, log_entries_proto: log_pb2.LogEntries):
        # Count log messages received that don't use the dropped field.
        messages_received = sum(
            1 if not log_proto.dropped else 0
            for log_proto in log_entries_proto.entries
        )
        dropped_log_count = (
            log_entries_proto.first_entry_sequence_id
            - self._expected_log_sequence_id
        )
        self._expected_log_sequence_id = (
            log_entries_proto.first_entry_sequence_id + messages_received
        )
        if dropped_log_count > 0:
            self._handle_log_drop_count(dropped_log_count, 'loss at transport')
        elif dropped_log_count < 0:
            _LOG.error('Log sequence ID is smaller than expected')
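
    # Worked example of the sequence accounting above (illustrative numbers):
    # if the previous batch left _expected_log_sequence_id at 10 and the next
    # batch arrives with first_entry_sequence_id == 12, then 12 - 10 == 2 logs
    # were lost at the transport and are reported via _handle_log_drop_count().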

    def _log_entries_proto_parser(self, log_entries_proto: log_pb2.LogEntries):
        self._check_for_dropped_logs(log_entries_proto)
        for log_proto in log_entries_proto.entries:
            decoded_timestamp = self.decode_timestamp(log_proto.timestamp)
            # Parse level and convert to logging module level number.
            level = (log_proto.line_level & 0x7) * 10
            if self.detokenizer:
                message = str(
                    decode_optionally_tokenized(
                        self.detokenizer, log_proto.message
                    )
                )
            else:
                message = log_proto.message.decode('utf-8')
            log = FormatStringWithMetadata(message)

            # Handle dropped count.
            if log_proto.dropped:
                drop_reason = (
                    log_proto.message.decode('utf-8').lower()
                    if log_proto.message
                    else 'enqueue failure on device'
                )
                self._handle_log_drop_count(log_proto.dropped, drop_reason)
                continue
            self._emit_device_log(
                level,
                decoded_timestamp,
                log.module,
                log.message,
                **dict(log.fields),
            )

    def _emit_device_log(
        self,
        level: int,
        timestamp: str,
        module_name: str,
        message: str,
        **metadata_fields,
    ):
        # Fields used for the console table view.
        fields = metadata_fields
        fields['timestamp'] = timestamp
        fields['msg'] = message
        fields['module'] = module_name

        # Format used for file or stdout logging.
        self.logger.log(
            level,
            '%s %s%s',
            timestamp,
            f'{module_name} '.lstrip(),
            message,
            extra=dict(extra_metadata_fields=fields),
        )

    def decode_timestamp(self, timestamp: int) -> str:
        """Decodes a timestamp to a human-readable value.

        Defaults to interpreting the input timestamp as nanoseconds since boot.
        Devices can override this to match their timestamp units.
        """
        if self.timestamp_decoder:
            return self.timestamp_decoder(timestamp)
        return str(datetime.timedelta(seconds=timestamp / 1e9))[:-3]

    def get_and_log_metrics(self) -> dict:
        """Retrieves the parsed metrics and logs them to the console."""
        metrics = metric_parser.parse_metrics(
            self.rpcs, self.detokenizer, self.rpc_timeout_s
        )

        def print_metrics(metrics, path):
            """Traverses dictionaries until a non-dict value is reached."""
            for path_name, metric in metrics.items():
                if isinstance(metric, dict):
                    print_metrics(metric, path + '/' + path_name)
                else:
                    _LOG.info('%s/%s: %s', path, path_name, str(metric))

        print_metrics(metrics, '')
        return metrics

    def snapshot_peak_stack_usage(self, thread_name: Optional[str] = None):
        """Fetches and logs peak stack usage for the device's threads."""
        _, rsp = self.rpcs.pw.thread.ThreadSnapshotService.GetPeakStackUsage(
            name=thread_name
        )

        thread_info = thread_pb2.SnapshotThreadInfo()
        for thread_info_block in rsp:
            for thread in thread_info_block.threads:
                thread_info.threads.append(thread)
        for line in str(ThreadSnapshotAnalyzer(thread_info)).splitlines():
            _LOG.info('%s', line)
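
# A custom timestamp decoder can be passed to the constructor when a device
# does not report nanoseconds since boot. A minimal sketch (the millisecond
# unit below is an assumption, not something this module prescribes):
#
#   def decode_ms_timestamp(timestamp: int) -> str:
#       return str(datetime.timedelta(milliseconds=timestamp))
#
#   device = Device(..., timestamp_decoder=decode_ms_timestamp)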