• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/bin/sh
2# Copyright (C) 2022 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16""":" # Shell script (in docstring to appease pylint)
17# Find and invoke hermetic python3 interpreter
18. "`dirname $0`"/../../../../../"trusty/vendor/google/aosp/scripts/envsetup.sh"
19exec "$PY3" "$0" "$@"
20# Shell script end
21
Generate metrics stats functions for every atom declared by the `Atom` message of a .proto file
23
24Command line (per protoc requirements):
25    $ROOT_DIR/prebuilts/libprotobuf/bin/protoc \
26        --proto_path=$SCRIPT_DIR \
27        --plugin=metrics_atoms=metrics_atoms_protoc_plugin.py \
28        --metrics_atoms_out=out \
29        --metrics_atoms_opt=pkg:android/trusty/stats \
30        test/atoms.proto
31
32Important:
* For debugging purposes, set the `dump-input` option as shown below:
        --metrics_atoms_opt=pkg:android/trusty/stats,dump-input:/tmp/dump.pb
35
* With the `dump-input` option set, invoke the protoc command line and observe
  the creation of `/tmp/dump.pb`.
38
39* Then invoke the debugging script
40  `metrics_atoms_protoc_debug.py /tmp/dump.pb android/trusty/stats`,
41  and hook the debugger at your convenience.
42
43This is the easiest debugging approach.
It is still possible to debug while protoc is invoking the plugin,
but that requires enabling remote debugging, which is more tedious.
46
47"""
48
49import functools
50import os
51import re
52import sys
53from pathlib import Path
54from dataclasses import dataclass
55from enum import Enum
56from typing import Dict, List
57
58# mypy: disable-error-code="attr-defined,valid-type"
59from google.protobuf.compiler import plugin_pb2 as plugin
60from google.protobuf.descriptor_pb2 import FileDescriptorProto
61from google.protobuf.descriptor_pb2 import DescriptorProto
62from google.protobuf.descriptor_pb2 import FieldDescriptorProto
63from jinja2 import Environment, PackageLoader
64
# Not referenced elsewhere in this file; presumably read by the companion
# debugging script (metrics_atoms_protoc_debug.py) -- TODO confirm.
DEBUG_READ_PB = False

# Templates are loaded from the `templates_package` Python package.
jinja_template_loader = PackageLoader("templates_package")
# Shared environment for rendering the generated C sources/headers:
# - trim_blocks / lstrip_blocks keep block tags from leaving stray newlines
#   and indentation in the output,
# - keep_trailing_newline preserves each template's final newline,
# - lines starting with "###" are template-only comments.
jinja_template_env = Environment(
    loader=jinja_template_loader,
    line_comment_prefix="###",
    keep_trailing_newline=True,
    lstrip_blocks=True,
    trim_blocks=True,
)
75
76
def snake_case(s: str):
    """Return `s` converted from CamelCase/camelCase to lower snake_case.

    An underscore is inserted before every ASCII uppercase letter that is
    not the first character, then the whole result is lower-cased, e.g.
    "VendorAtom" -> "vendor_atom". Consecutive capitals are split one by
    one ("ABc" -> "a_bc").
    """
    pieces = []
    for pos, char in enumerate(s):
        # Same boundary as the regex (?<!^)(?=[A-Z]): ASCII capitals only,
        # never at the very start of the string.
        if pos > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
80
81
class StatsLogGenerationError(Exception):
    """Error preventing the log API generation."""
84
85
@dataclass
class VendorAtomEnumValue:
    """One value of a proto enum used by vendor atom fields.

    Fields:
        name: the enum value's identifier.
        value: the enum value's number.
    """
    name: str
    value: int

    @functools.cached_property
    def name_len(self):
        """Number of characters in `name` (computed once, then cached)."""
        length = len(self.name)
        return length
94
95
@dataclass
class VendorAtomEnum:
    """A proto enum referenced by vendor atom fields.

    Fields:
        name: the enum's proto name.
        values: the enum's values, in the order the descriptor lists them.
    """
    name: str
    values: List[VendorAtomEnumValue]

    @functools.cached_property
    def values_name_len(self):
        """Length of the longest value name.

        Raises ValueError when `values` is empty.
        """
        names = [value.name for value in self.values]
        return max(map(len, names))

    @functools.cached_property
    def c_name(self):
        """C identifier for the enum: "stats_" followed by snake_case(name)."""
        suffix = snake_case(self.name)
        return "stats_" + suffix
108
109
class VendorAtomValueTag(Enum):
    """Kind of value slot a scalar proto field is mapped to in a vendor atom."""
    intValue = 0
    longValue = 1
    floatValue = 2
    stringValue = 3

    @classmethod
    def get_tag(cls, label: int,
                type_: int,
                type_name: str) -> 'VendorAtomValueTag':
        """Map a proto field's label and type to a value tag.

        Args:
            label: the field's `FieldDescriptorProto.Label` value.
            type_: the field's `FieldDescriptorProto.Type` value.
            type_name: the field's type name, only used in error messages
                (typically empty for scalar fields).

        Returns:
            The tag matching the field type. 32-bit integer types (signed,
            unsigned, zigzag, fixed-width) and enums map to intValue, their
            64-bit counterparts to longValue, float/double to floatValue and
            string to stringValue.

        Raises:
            StatsLogGenerationError: for repeated fields, and for bool, bytes
                and any other type (messages, groups) with no vendor atom
                representation.
        """
        if label == FieldDescriptorProto.LABEL_REPEATED:
            raise StatsLogGenerationError(
                f"repeated fields are not supported in Android"
                f" please fix {type_name}({type_})")
        if type_ in (
                FieldDescriptorProto.TYPE_DOUBLE,
                FieldDescriptorProto.TYPE_FLOAT,
        ):
            return VendorAtomValueTag.floatValue
        if type_ in (
                FieldDescriptorProto.TYPE_INT32,
                FieldDescriptorProto.TYPE_SINT32,
                FieldDescriptorProto.TYPE_UINT32,
                FieldDescriptorProto.TYPE_FIXED32,
                # sfixed32 is a plain signed 32-bit int, like int32/sint32.
                FieldDescriptorProto.TYPE_SFIXED32,
                FieldDescriptorProto.TYPE_ENUM,
        ):
            return VendorAtomValueTag.intValue
        if type_ in (
                FieldDescriptorProto.TYPE_INT64,
                FieldDescriptorProto.TYPE_SINT64,
                FieldDescriptorProto.TYPE_UINT64,
                FieldDescriptorProto.TYPE_FIXED64,
                # sfixed64 is a plain signed 64-bit int, like int64/sint64.
                FieldDescriptorProto.TYPE_SFIXED64,
        ):
            return VendorAtomValueTag.longValue
        if type_ == FieldDescriptorProto.TYPE_BOOL:
            raise StatsLogGenerationError(
                f"boolean fields are not supported in Android"
                f" please fix {type_name}({type_})")
        if type_ == FieldDescriptorProto.TYPE_STRING:
            return VendorAtomValueTag.stringValue
        if type_ == FieldDescriptorProto.TYPE_BYTES:
            raise StatsLogGenerationError(
                f"byte[] fields are not supported in Android"
                f" please fix {type_name}({type_})")
        raise StatsLogGenerationError(
            f"field type {type_name}({type_}) cannot be an atom field")
162
163
164@dataclass
165class VendorAtomValue:
166    name: str
167    tag: VendorAtomValueTag
168    enum: VendorAtomEnum
169    idx: int
170
171    @functools.cached_property
172    def c_name(self):
173        return snake_case(self.name)
174
175    @functools.cached_property
176    def is_string(self):
177        return self.tag in [
178            VendorAtomValueTag.stringValue,
179        ]
180
181    @functools.cached_property
182    def c_type(self):
183        if self.enum:
184            return f"enum {self.enum.c_name} "
185        match self.tag:
186            case VendorAtomValueTag.intValue:
187                return 'int32_t '
188            case VendorAtomValueTag.longValue:
189                return 'int64_t '
190            case VendorAtomValueTag.floatValue:
191                return 'float '
192            case VendorAtomValueTag.stringValue:
193                return 'const char *'
194            case _:
195                raise StatsLogGenerationError(f"unknown tag {self.tag}")
196
197    @functools.cached_property
198    def default_value(self):
199        if self.enum:
200            try:
201                default = [
202                    v.name
203                    for v in self.enum.values
204                    if v.name.lower().find("invalid") > -1 or
205                    v.name.lower().find("unknown") > -1
206                ][0]
207            except IndexError:
208                default = '0'
209            return default
210        match self.tag:
211            case VendorAtomValueTag.intValue:
212                return '0'
213            case VendorAtomValueTag.longValue:
214                return '0L'
215            case VendorAtomValueTag.floatValue:
216                return '0.'
217            case VendorAtomValueTag.stringValue:
218                return '"", 0UL'
219            case _:
220                raise StatsLogGenerationError(f"unknown tag {self.tag}")
221
222    @functools.cached_property
223    def stats_setter_name(self):
224        match self.tag:
225            case VendorAtomValueTag.intValue:
226                return 'set_int_value_at'
227            case VendorAtomValueTag.longValue:
228                return 'set_long_value_at'
229            case VendorAtomValueTag.floatValue:
230                return 'set_float_value_at'
231            case VendorAtomValueTag.stringValue:
232                return 'set_string_value_at'
233            case _:
234                raise StatsLogGenerationError(f"unknown tag {self.tag}")
235
236
@dataclass
class VendorAtom:
    """An atom for which a stats logging function is generated.

    Fields:
        name: the name of the atom's proto message.
        atom_id: the atom id (the field number in the `Atom` message).
        values: the atom's value fields.
    """
    name: str
    atom_id: int
    values: List[VendorAtomValue]

    @functools.cached_property
    def c_name(self):
        """snake_case form of `name` (exposed as the atom's C name)."""
        converted = snake_case(self.name)
        return converted
246
247
class VendorAtomEnv:
    """Static class gathering all enums and atoms required for code generation.

    process_file() assigns `enums` and `atoms`, then passes this class to the
    templates as `env`; `len` and `snake_case` are helpers exposed to them.
    """
    # Enums keyed by fully-qualified proto type name (e.g. ".pkg.MyEnum").
    # Starts empty so reading it before process_file() runs does not raise
    # AttributeError; process_file() rebinds (never mutates) it.
    enums: Dict[str, VendorAtomEnum] = {}
    # Atoms in the order their fields appear in the `Atom` message; also
    # rebound by process_file().
    atoms: List[VendorAtom] = []

    @classmethod
    def len(cls, ll: List):
        """Return len(ll), callable from templates as env.len(...)."""
        return len(ll)

    @classmethod
    def snake_case(cls, s: str):
        """Return the module-level snake_case(s), callable from templates."""
        return snake_case(s)
261
262
def assert_reverse_domain_name_field(msg_dict: Dict[str, DescriptorProto],
                                     atom: DescriptorProto):
    """Check that every atom message has reverse_domain_name as its first field.

    Each message-typed field of the `Atom` message is an atom. make_atom()
    drops the atom message's `reverse_domain_name` field from the
    VendorAtomValue list and uses (field position - 1) as value index, which
    is only correct when `reverse_domain_name` appears exactly once and is
    the first field.

    Args:
        msg_dict: top-level messages keyed by fully-qualified type name.
        atom: the `Atom` message descriptor.

    Raises:
        StatsLogGenerationError: if an atom message violates that layout.
            An explicit raise is used instead of `assert` so the check is
            not stripped when Python runs with -O.
    """
    for f in atom.field:
        if f.type != FieldDescriptorProto.TYPE_MESSAGE:
            continue
        atom_msg = msg_dict[f.type_name]
        positions = [
            idx for idx, ff in enumerate(atom_msg.field)
            if ff.name == "reverse_domain_name"
        ]
        if positions != [0]:
            raise StatsLogGenerationError(
                f"atom {atom_msg.name} (Atom.{f.name}) must declare"
                f" reverse_domain_name exactly once, as its first field;"
                f" found it at field positions {positions}")
278
279
def get_enum(f: FieldDescriptorProto):
    """Return the VendorAtomEnum an enum-typed field refers to, else None.

    Only enums registered in VendorAtomEnv.enums can be resolved; process_file()
    registers the top-level enums of the .proto file being processed.

    Args:
        f: the field descriptor to inspect.

    Raises:
        StatsLogGenerationError: if `f` is an enum field whose enum is not
            registered, e.g. one nested in a message or imported from
            another file.
    """
    if f.type != FieldDescriptorProto.TYPE_ENUM:
        return None
    try:
        return VendorAtomEnv.enums[f.type_name]
    except KeyError as e:
        raise StatsLogGenerationError(
            f"enum {f.type_name} of field {f.name} is not a top-level enum"
            f" of the processed .proto file") from e
284
285
def make_atom(msg_dict: Dict[str, DescriptorProto],
              field: FieldDescriptorProto):
    """Build the VendorAtom for one message-typed field of the `Atom` message.

    Each such field points to an atom message for which a stats_log function
    is generated; `field.number` is the atomId uniquely identifying the
    VendorAtom. Every field of the atom message except `reverse_domain_name`
    is added as a VendorAtomValue. Value indices are the field positions
    minus one, which assumes `reverse_domain_name` is the first field (see
    assert_reverse_domain_name_field()).

    Args:
        msg_dict: top-level messages keyed by fully-qualified type name.
        field: a TYPE_MESSAGE field of the `Atom` message.

    Raises:
        StatsLogGenerationError: if a field of the atom message cannot be
            represented; the message is prefixed with "<Atom>.<field>: ".
    """
    assert field.type == FieldDescriptorProto.TYPE_MESSAGE
    atom_msg = msg_dict[field.type_name]
    values = []
    for idx, ff in enumerate(atom_msg.field):
        if ff.name == "reverse_domain_name":
            continue
        try:
            tag = VendorAtomValueTag.get_tag(ff.label, ff.type, ff.type_name)
        except StatsLogGenerationError as e:
            # get_tag only sees the type (type_name is typically empty for
            # scalar fields), so name the offending atom and field here.
            raise StatsLogGenerationError(
                f"{atom_msg.name}.{ff.name}: {e}") from e
        values.append(
            VendorAtomValue(name=ff.name,
                            tag=tag,
                            enum=get_enum(ff),
                            idx=idx - 1))
    return VendorAtom(atom_msg.name, field.number, values)
306
307
def process_file(proto_file: FileDescriptorProto,
                 response: plugin.CodeGeneratorResponse,
                 pkg: str = '') -> None:
    """Generate the C source and header for the atoms of one .proto file.

    Registers the file's top-level enums and the atoms declared by its `Atom`
    message in VendorAtomEnv, renders metrics_atoms.c.j2 and
    metrics_atoms.h.j2, and appends the results to `response` as
    `<pkg>/<stem>.c` and `include/<pkg>/<stem>.h`.

    Args:
        proto_file: descriptor of the .proto file to process.
        response: plugin response the generated files are added to.
        pkg: '/'-separated sub-directory for the generated files.

    Raises:
        StatsLogGenerationError: if the file has no `Atom` message or an atom
            cannot be represented.
    """

    def get_uri(type_name: str):
        # Fully-qualified name with a leading '.', the form used by
        # FieldDescriptorProto.type_name.
        paths = [x for x in [proto_file.package, type_name] if len(x) > 0]
        return f".{'.'.join(paths)}"

    msg_list = list(proto_file.message_type)
    msg_dict = {get_uri(msg.name): msg for msg in msg_list}

    # Get Atom message and parse its atom fields
    # recording the atomId in the process
    atom = next((msg for msg in msg_list if msg.name == "Atom"), None)
    if atom is None:
        raise StatsLogGenerationError(
            f"the Atom message is missing from {proto_file.name}")

    VendorAtomEnv.enums = {
        get_uri(e.name): VendorAtomEnum(
            name=e.name,
            values=[
                VendorAtomEnumValue(name=ee.name, value=ee.number)
                for ee in e.value
            ])
        for e in proto_file.enum_type
    }

    assert_reverse_domain_name_field(msg_dict, atom)
    VendorAtomEnv.atoms = [
        make_atom(msg_dict, field)
        for field in atom.field
        if field.type == FieldDescriptorProto.TYPE_MESSAGE
    ]

    proto_name = Path(proto_file.name).stem
    # Empty components (pkg == '' or stray/duplicate '/') are dropped.
    pkg_dirs = [part for part in pkg.split('/') if part]
    tm_env = {"env": VendorAtomEnv}
    for tpl, ext in (
        ("metrics_atoms.c.j2", 'c'),
        ("metrics_atoms.h.j2", 'h'),
    ):
        rendered = jinja_template_env.get_template(tpl).render(**tm_env)
        file_path = ['include', *pkg_dirs] if ext == 'h' else list(pkg_dirs)
        file_path.append(f"{proto_name}.{ext}")
        out = response.file.add()
        # protoc requires '/' as the separator in generated file names on
        # every platform, so os.path.join must not be used here.
        out.name = '/'.join(file_path)
        out.content = rendered
357
358
359def process_data(data: bytes, pkg: str = '') -> None:
360    request = plugin.CodeGeneratorRequest()
361    request.ParseFromString(data)
362
363    dump_input_file = None
364    options = request.parameter.split(',') if request.parameter else []
365    for opt in options:
366        match opt.split(':'):
367            case ["pkg", value]:
368                pkg = value
369            case ["dump-input", value]:
370                dump_input_file = value
371            case [""]:
372                pass
373            case other:
374                raise ValueError(f"unknown parameter {other}")
375
376    if dump_input_file:
377        # store the pb file for easy debug
378        with open(dump_input_file, "wb") as f_data:
379            f_data.write(data)
380
381    # Create a response
382    response = plugin.CodeGeneratorResponse()
383
384    for proto_file in request.proto_file:
385        process_file(proto_file, response, pkg)
386
387    # Serialize response and write to stdout
388    output = response.SerializeToString()
389
390    # Write to stdout per the protoc plugin expectation
391    # (protoc consumes this output)
392    sys.stdout.buffer.write(output)
393
394
def main() -> None:
    """protoc plugin entry point: request read from stdin, response on stdout."""
    process_data(sys.stdin.buffer.read())


if __name__ == "__main__":
    main()
402