• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Avatar metrics trace."""
16
17import atexit
18import time
19import types
20
21from avatar.metrics.trace_pb2 import DebugAnnotation
22from avatar.metrics.trace_pb2 import ProcessDescriptor
23from avatar.metrics.trace_pb2 import ThreadDescriptor
24from avatar.metrics.trace_pb2 import Trace
25from avatar.metrics.trace_pb2 import TracePacket
26from avatar.metrics.trace_pb2 import TrackDescriptor
27from avatar.metrics.trace_pb2 import TrackEvent
28from google.protobuf import any_pb2
29from google.protobuf import message
30from mobly.base_test import BaseTestClass
31from pathlib import Path
32from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Tuple, Union
33
34if TYPE_CHECKING:
35    from avatar import PandoraDevices
36    from avatar.pandora_client import PandoraClient
37else:
38    PandoraClient = object
39    PandoraDevices = object
40
# Device -> uuid of its thread track for the current test (assigned by the setup_test
# wrapper installed in `hook_test`). Devices absent from this map are not traced.
devices_id: Dict[PandoraClient, int] = {}
# Device -> id of the process track (one per test case) its thread track belongs to; also
# used as the `trusted_packet_sequence_id` of the device's packets.
devices_process_id: Dict[PandoraClient, int] = {}
# All trace packets collected so far; serialized by `dump_trace` at interpreter exit.
packets: List[TracePacket] = []
# Monotonic clock origin (ns) for event timestamps; reset at the start of every test.
genesis: int = time.monotonic_ns()
# Directory that receives `avatar.trace`; stays None (nothing is written) until
# `hook_test` sets it.
output_path: Optional[Path] = None
# Last identifier handed out by `next_id`.
# NOTE(review): this name shadows the builtin `id` for the whole module.
id: int = 0
47
48
def next_id() -> int:
    """Advance the module-wide counter and return its new value.

    The ids are used both as process and as thread track identifiers.
    """
    global id
    id = id + 1
    return id
53
54
@atexit.register
def dump_trace() -> None:
    """Write every collected packet to ``<output_path>/avatar.trace`` when the interpreter exits.

    Skipped entirely when ``output_path`` was never set.
    """
    if output_path is None:
        return
    serialized = Trace(packet=packets)
    destination = output_path / "avatar.trace"
    with open(destination, "wb") as sink:
        sink.write(serialized.SerializeToString())
63
64
def hook_test(test: BaseTestClass, devices: PandoraDevices) -> None:
    """Instrument ``test`` so each test case is traced as a process with one thread per device.

    On first use this also fixes ``output_path`` to the directory two levels above the
    current Mobly test output directory, which is where ``dump_trace`` writes the trace.
    The test's ``setup_test`` is then wrapped. Before delegating to the original, the
    wrapper does the following:
    - resets ``genesis``;
    - emits a process track named ``<TestClass>.<test name>``;
    - emits one thread track per device;
    - records the device-to-track mappings.
    """
    global output_path

    if output_path is None:
        mobly_output_path: str = test.current_test_info.output_path  # type: ignore
        # Step out of the per-class and per-method directories.
        output_path = Path(mobly_output_path).joinpath('..', '..').resolve()

    wrapped_setup = test.setup_test

    def setup_test(self: BaseTestClass) -> None:
        global genesis
        genesis = time.monotonic_ns()

        pid = next_id()
        process_name = f"{self.__class__.__name__}.{self.current_test_info.name}"
        process_track = TrackDescriptor(uuid=pid, process=ProcessDescriptor(pid=pid, process_name=process_name))
        packets.append(TracePacket(track_descriptor=process_track))

        for device in devices:
            tid = next_id()
            devices_process_id[device] = pid
            devices_id[device] = tid
            thread = ThreadDescriptor(thread_name=device.name, pid=pid, tid=tid)
            thread_track = TrackDescriptor(uuid=tid, parent_uuid=pid, thread=thread)
            packets.append(TracePacket(track_descriptor=thread_track))

        wrapped_setup()

    test.setup_test = types.MethodType(setup_test, test)
102
103
class AsTrace(Protocol):
    """Structural type for objects that can encode themselves as a ``TracePacket``."""

    def as_trace(self) -> TracePacket:
        """Return the trace packet describing this object."""
106
107
class Callsite(AsTrace):
    """A single call made on a device, logged as it progresses and traced as a slice.

    Construction timestamps the call relative to ``genesis``, gives it a process-wide
    sequential id (which also selects its log color) and logs its opening line.
    ``output``/``input`` record intermediate events and ``end`` closes the call. For
    devices that have a track in ``devices_id``, ``end`` also emits the trace packets.
    """

    # Last id handed out by `next_id`; class-level so ids are unique across all callsites.
    id_counter = 0

    @classmethod
    def next_id(cls) -> int:
        """Return the next callsite id (1, 2, 3, ...)."""
        cls.id_counter += 1
        return cls.id_counter

    def __init__(self, device: PandoraClient, name: Union[bytes, str], message: Any) -> None:
        # Nanoseconds since `genesis`, the same origin used by this call's events.
        self.at = time.monotonic_ns() - genesis
        self.name = name if isinstance(name, str) else name.decode('utf-8')
        self.device = device
        self.message = message
        self.events: List[CallEvent] = []
        self.id = Callsite.next_id()

        device.log.info(f"{self}")

    def pretty(self) -> str:
        """Return a short label for the call, e.g. ``Service.Method(fields)``.

        The label is built from the name as follows: drop the first character, keep only
        what follows the last ``.``, then turn ``/`` into ``.``. This presumably targets
        gRPC-style ``/package.Service/Method`` names (TODO confirm). Without a message,
        the label is prefixed with ``%<id>`` instead of carrying a field list.
        """
        name_pretty = self.name[1:].split('.')[-1].replace('/', '.')
        if self.message is None:
            return f"%{self.id} {name_pretty}"
        message_pretty, _ = debug_message(self.message)
        return f"{name_pretty}({message_pretty})"

    def __str__(self) -> str:
        return f"{str2color('╭──', self.id)} {self.pretty()}"

    def output(self, message: Any) -> None:
        """Record (and log) an intermediate ``CallOutput`` event."""
        self.events.append(CallOutput(self, message))

    def input(self, message: Any) -> None:
        """Record (and log) an intermediate ``CallInput`` event."""
        self.events.append(CallInput(self, message))

    def end(self, message: Any) -> None:
        """Close the call and, if its device is traced, emit the slice and its events.

        The closing event is always recorded; creating it is what logs the ``╰──`` line.
        This keeps the console output balanced with the ``╭──`` line logged by
        ``__init__``. Trace packets are only emitted for devices present in
        ``devices_id``, because ``as_trace`` needs their track ids.
        """
        self.events.append(CallEnd(self, message))
        if self.device not in devices_id:
            return
        packets.append(self.as_trace())
        packets.extend(event.as_trace() for event in self.events)

    def as_trace(self) -> TracePacket:
        """Encode the start of the call as a slice-begin event on the device's track."""
        return TracePacket(
            timestamp=self.at,
            track_event=TrackEvent(
                name=self.name,
                type=TrackEvent.Type.TYPE_SLICE_BEGIN,
                track_uuid=devices_id[self.device],
                debug_annotations=(
                    None
                    if self.message is None
                    else [
                        DebugAnnotation(
                            name=self.message.__class__.__name__, dict_entries=debug_message(self.message)[1]
                        )
                    ]
                ),
            ),
            trusted_packet_sequence_id=devices_process_id[self.device],
        )
170
171
class CallEvent(AsTrace):
    """An event (output, input or end) that occurs during a ``Callsite``.

    The event is timestamped relative to ``genesis`` and logged on the callsite's device as
    soon as it is constructed. Subclasses choose the log prefix and arrow in ``__str__``.
    """

    def __init__(self, callsite: Callsite, message: Any) -> None:
        self.at = time.monotonic_ns() - genesis
        self.callsite = callsite
        self.message = message

        callsite.device.log.info(str(self))

    def __str__(self) -> str:
        prefix = str2color('╰──', self.callsite.id)
        return prefix + ' ' + self.stringify('⟶ ')

    def as_trace(self) -> TracePacket:
        """Encode the event as an instant ``TrackEvent`` on the callsite device's track."""
        site = self.callsite
        track_uuid = devices_id[site.device]
        if self.message is None:
            annotations = None
        else:
            _, entries = debug_message(self.message)
            annotations = [DebugAnnotation(name=self.message.__class__.__name__, dict_entries=entries)]
        event = TrackEvent(
            name=site.name,
            type=TrackEvent.Type.TYPE_INSTANT,
            track_uuid=track_uuid,
            debug_annotations=annotations,
        )
        return TracePacket(
            timestamp=self.at,
            track_event=event,
            trusted_packet_sequence_id=devices_process_id[site.device],
        )

    def stringify(self, direction: str) -> str:
        """Format ``[<seconds since call start>s] <call> <direction> (<message>)`` in the call's color."""
        site = self.callsite
        details = "" if self.message is None else debug_message(self.message)[0]
        elapsed = (self.at - site.at) / 1000000000
        stamp = str2color(f"[{elapsed:.3f}s]", site.id)
        arrow = str2color(direction, site.id)
        return f"{stamp} {site.pretty()} {arrow} ({details})"
209
210
class CallOutput(CallEvent):
    """Intermediate event of a call, logged with a ``├──`` prefix and a ``⟶`` arrow.

    It is traced exactly like the base ``CallEvent`` (an instant event), so ``as_trace`` is
    simply inherited instead of being overridden with a pure ``super()`` delegation.
    """

    def __str__(self) -> str:
        return f"{str2color('├──', self.callsite.id)} {self.stringify('⟶ ')}"
217
218
class CallInput(CallEvent):
    """Intermediate event of a call, logged with a ``├──`` prefix and a ``⟵`` arrow.

    It is traced exactly like the base ``CallEvent`` (an instant event), so ``as_trace`` is
    simply inherited instead of being overridden with a pure ``super()`` delegation.
    """

    def __str__(self) -> str:
        return f"{str2color('├──', self.callsite.id)} {self.stringify('⟵ ')}"
225
226
class CallEnd(CallEvent):
    """Final event of a call: logged with a closing ``╰──`` and traced as the end of the slice."""

    def __str__(self) -> str:
        corner = str2color('╰──', self.callsite.id)
        return corner + ' ' + self.stringify('⟶ ')

    def as_trace(self) -> TracePacket:
        # Same packet as the base event (name, track, annotations, sequence id); only the
        # event type differs, closing the slice opened by `Callsite.as_trace`.
        packet = super().as_trace()
        packet.track_event.type = TrackEvent.Type.TYPE_SLICE_END
        return packet
250
251
def debug_value(v: Any) -> Tuple[Any, Dict[str, Any]]:
    """Convert one message field value into ``(pretty value, DebugAnnotation kwargs)``.

    The first element is a compact form used in log lines. The second is a dict of
    keyword arguments selecting the ``DebugAnnotation`` value field that matches the
    value's type.
    """
    if isinstance(v, any_pb2.Any):
        return '...', {'string_value': f'{v}'}
    if isinstance(v, message.Message):
        pretty, entries = debug_message(v)
        return pretty, {'dict_entries': entries}
    if isinstance(v, bytes):
        # Short byte strings are shown inline; longer ones are elided in the pretty form.
        return (v if len(v) < 16 else '...'), {'string_value': f'{v!r}'}
    # `bool` must be checked before `int`, since it is a subclass of `int`.
    if isinstance(v, bool):
        return v, {'bool_value': v}
    if isinstance(v, int):
        return v, {'int_value': v}
    if isinstance(v, float):
        return v, {'double_value': v}
    if isinstance(v, str):
        return v, {'string_value': v}
    # Anything else is treated as a sequence of values. If iterating it or building the
    # element annotations fails, fall back to its string form. Catch `Exception` rather than
    # using a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    try:
        return v, {'array_values': [DebugAnnotation(**debug_value(x)[1]) for x in v]}  # type: ignore
    except Exception:
        return v, {'string_value': f'{v}'}
272
273
def debug_message(msg: message.Message) -> Tuple[Dict[str, Any], List[DebugAnnotation]]:
    """Summarize the set fields of ``msg`` for logging and tracing.

    Returns a ``{field name: pretty value}`` dict and a parallel list of ``DebugAnnotation``.
    Some 6-byte ``bytes`` fields are rendered as colon-separated uppercase hex (``AA:BB:...``):
    those whose name, or the name of their enclosing oneof, contains ``address``.
    """
    summary: Dict[str, Any] = {}
    annotations: List[DebugAnnotation] = []
    for field, value in msg.ListFields():
        oneof = field.containing_oneof
        is_address = (
            isinstance(value, bytes)
            and len(value) == 6
            and ('address' in field.name or bool(oneof and 'address' in oneof.name))
        )
        if is_address:
            text = ':'.join(f'{byte:02X}' for byte in value)
            summary[field.name] = text
            annotations.append(DebugAnnotation(name=field.name, string_value=text))
            continue
        pretty, annotation_kwargs = debug_value(value)
        summary[field.name] = pretty
        annotations.append(DebugAnnotation(name=field.name, **annotation_kwargs))
    return summary, annotations
291
292
def str2color(s: str, id: int) -> str:
    """Wrap ``s`` in ANSI escapes for a bold 256-color foreground derived from ``id``.

    The color index is ``17 + (id * 10) % 213``, i.e. always in ``[17, 229]``, so
    consecutive ids land on well-separated palette entries.
    """
    escape = "\x1b["
    first, last = 17, 230
    shade = first + (id * 10) % (last - first)
    return "".join((escape, "1;38;5;%dm" % shade, escape, "1m", s, escape, "0m"))
299