• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tooling to generate C++ constants from a yaml sensor definition."""
15
16import argparse
17import importlib.resources
18import re
19import sys
20import typing
21from collections.abc import Sequence
22from dataclasses import dataclass, fields, is_dataclass
23from typing import Any
24import types
25
26import jinja2
27import yaml
28
29
def kid_from_name(name: str) -> str:
    """Build a C++ constant-style identifier from a free-form name.

    The name is split on runs of whitespace, underscores, dashes and commas;
    each piece is title-cased and the pieces are concatenated behind a
    leading 'k'.

    Example:
      "sample_rate" becomes "kSampleRate"

    Args:
      name: the name to convert to an ID

    Returns:
      The 'k' prefixed, camel cased identifier

    """
    words = re.split(r"[\s_\-\,]+", name)
    camel_cased = "".join(map(str.title, words))
    return f"k{camel_cased}"
44
45
@dataclass
class UnitsSpec:
    """Typed form of one units definition taken from the input YAML."""

    # Name of the unit.
    name: str
    # Symbol of the unit.
    symbol: str
52
53
@dataclass
class AttributeSpec:
    """Typed form of one attribute definition taken from the input YAML."""

    # Name of the attribute.
    name: str
    # Free-form description of the attribute.
    description: str
60
61
@dataclass
class ChannelSpec:
    """Typed form of one channel definition taken from the input YAML."""

    # Name of the channel.
    name: str
    # Free-form description of the channel.
    description: str
    # Units of the channel, stored as a string.
    # NOTE(review): presumably a key into the units mapping - confirm.
    units: str
69
70
@dataclass
class TriggerSpec:
    """Typed form of one trigger definition taken from the input YAML."""

    # Name of the trigger.
    name: str
    # Free-form description of the trigger.
    description: str
77
78
79@dataclass
80class SensorAttributeSpec:
81    """Typing for the SensorAttribute definition dictionary."""
82
83    channel: str | None
84    trigger: str | None
85    attribute: str
86    units: str
87
88
89@dataclass
90class CompatibleSpec:
91    """Typing for the Compatible dictionary."""
92
93    org: str | None
94    part: str
95
96
@dataclass
class SensorSpec:
    """Typed form of one sensor definition taken from the input YAML."""

    # Free-form description of the sensor.
    description: str
    # The sensor's ``compatible`` entry (org / part).
    compatible: CompatibleSpec
    # Bus names, stored as strings.
    supported_buses: list[str]
    # Attribute entries listed for the sensor.
    attributes: list[SensorAttributeSpec]
    # Lists of channel specs, keyed by string.
    channels: dict[str, list[ChannelSpec]]
    # Trigger entries; their element type is not constrained here.
    triggers: list[Any]
    # Additional data; values are not constrained here.
    extras: dict[str, Any]
108
109
@dataclass
class Args:
    """Typed container for the parsed command line arguments."""

    # Components of the output package name.
    package: Sequence[str]
    # Output language selected on the command line.
    language: str
    # Value of the ``--zephyr`` flag.
    zephyr: bool
117
118
@dataclass
class InputSpec:
    """Typed form of the whole input document, one mapping per section."""

    # Units definitions, keyed by string.
    units: dict[str, UnitsSpec]
    # Attribute definitions, keyed by string.
    attributes: dict[str, AttributeSpec]
    # Channel definitions, keyed by string.
    channels: dict[str, ChannelSpec]
    # Trigger definitions, keyed by string.
    triggers: dict[str, TriggerSpec]
    # Sensor definitions, keyed by string.
    sensors: dict[str, SensorSpec]
128
129
def is_list_type(t: Any) -> bool:  # noqa: ANN401
    """
    Checks if the given type `t` is a list alias such as ``list[int]``.

    Only types whose `typing.get_origin()` is `list` are recognised (for
    example ``list[str]`` or ``typing.List[str]``). The bare ``list`` class
    has no origin and is therefore reported as False.

    Args:
        t: The type to check.

    Returns:
        True if the origin of `t` is `list`, False otherwise.

    """
    # The previous extra clause `(origin is list and get_args(t) == ())` was
    # already implied by `origin is list`, so it could never change the
    # result and has been dropped.
    return typing.get_origin(t) is list
143
144
def is_primitive(value: Any) -> bool:  # noqa: ANN401
    """Tell whether `value` is an instance of a primitive type.

    The primitive types are int, float, complex, str and bool.

    Args:
        value: The value to check.

    Returns:
        True if the value is an instance of one of the primitive types,
        False otherwise.

    """
    primitive_types = (int, float, complex, str, bool)
    return isinstance(value, primitive_types)
156
157
158def is_union(t: Any) -> bool:  # noqa: ANN401
159    """Check if the given type is a union
160
161    Args:
162        t: The type to check.
163
164    Returns:
165        True if `t` is a union type, False otherwise.
166
167    """
168    return (
169        typing.get_origin(t) is typing.Union
170        or typing.get_origin(t) is types.UnionType
171    )
172
173
def create_dataclass_from_dict(
    cls: Any,  # noqa: ANN401
    data: Any,  # noqa: ANN401
    indent: int = 0,
) -> Any:  # noqa: ANN401
    """Recursively converts parsed input data into the type given by `cls`.

    Conversion rules, applied in order:
    - If `cls` is `Any`, `data` is returned unchanged.
    - If `cls` is a list alias, every item of `data` is converted to the
      alias' item type (an unparameterised alias is treated as a list of
      `Any`).
    - If `data` is a primitive value, it is returned unchanged.
    - Otherwise `cls` must be a dataclass and `data` a dictionary: every
      dataclass field is looked up by its name (falling back to the name with
      '_' replaced by '-'), converted recursively when the field type is a
      list alias, a parameterised dict or a dataclass, and passed to the
      `cls` constructor.

    Args:
      cls: The target type (a dataclass, a list alias, or `Any`).
      data: The value to convert.
      indent: Nesting depth. It grows by 2 on every recursive call and is
        not otherwise used by this function.

    Returns:
      The converted value: a list, the unchanged `data`, or an instance of
      `cls`.

    Raises:
      AssertionError: If a field that is not optional has no value in `data`.
      TypeError: If `cls` is neither `Any`, a list alias nor a dataclass and
        `data` is not a primitive value (raised by `dataclasses.fields`).

    """
    # `Any` describes no structure to convert into. Without this guard,
    # non-primitive values below e.g. `dict[str, Any]` or `list[Any]` would
    # reach `fields(Any)` and fail with a TypeError.
    if cls is Any:
        return data

    if is_list_type(cls):
        type_args = typing.get_args(cls)
        # An unparameterised alias has no item type; keep items as they are.
        item_type = type_args[0] if type_args else Any
        return [
            create_dataclass_from_dict(item_type, item, indent + 2)
            for item in data
        ]

    if is_primitive(data):
        return data

    field_values: dict[str, Any] = {}
    for field in fields(cls):
        # Keys may be spelled with '-' where the dataclass field uses '_'.
        field_value = data.get(field.name)
        if field_value is None:
            field_value = data.get(field.name.replace("_", "-"))

        if field_value is None:
            if is_union(field.type) and type(None) in typing.get_args(
                field.type
            ):
                # We have an optional field and no value, skip it
                field_values[field.name] = None
                continue
            # Raised explicitly instead of using `assert` so the check is
            # not stripped under `python -O`. AssertionError is kept so
            # existing callers observe the same exception type.
            owner = getattr(cls, "__name__", cls)
            error = f"Missing required field '{field.name}' for {owner}"
            raise AssertionError(error)

        # We need to check if the field is a List, dictionary, or another
        # dataclass. If it is, recurse.
        if is_list_type(field.type):
            # The list branch at the top of this function converts the items.
            field_value = create_dataclass_from_dict(
                field.type, field_value, indent + 2
            )
        elif dict in getattr(field.type, "__mro__", []):
            # We might not have types specified in the dataclass
            value_types = typing.get_args(field.type)
            if value_types:
                value_type = value_types[1]
                field_value = {
                    key: create_dataclass_from_dict(value_type, val, indent + 2)
                    for key, val in field_value.items()
                }
        elif is_dataclass(field.type):
            field_value = create_dataclass_from_dict(
                field.type, field_value, indent + 2
            )

        field_values[field.name] = field_value

    return cls(**field_values)
236
237
def main() -> None:
    """Main entry point

    This function will:
    - Get CLI flags
    - Read YAML from stdin and convert it to an `InputSpec`
    - Load every '.jinja' resource of the 'pw_sensor.templates' package
    - Render the template of the selected language
    - Write the rendered text to stdout

    Raises:
      ValueError: If the selected language is not supported.
    """
    args = get_args()
    yaml_input = yaml.safe_load(sys.stdin)
    spec: InputSpec = create_dataclass_from_dict(InputSpec, yaml_input)

    # Use the files() traversable API: importlib.resources.contents() has
    # been deprecated since Python 3.11 in favour of files(...).iterdir().
    template_dir = importlib.resources.files("pw_sensor.templates")
    jinja_templates = {
        entry.name: entry.read_text(encoding="utf-8")
        for entry in template_dir.iterdir()
        if entry.name.endswith(".jinja")
    }
    # NOTE(review): autoescape=True applies Jinja's HTML escaping to rendered
    # values - confirm this is intended for the generated C++ output.
    environment = jinja2.Environment(
        loader=jinja2.DictLoader(jinja_templates),
        autoescape=True,
        # Trim whitespace in templates
        trim_blocks=True,
        lstrip_blocks=True,
    )
    # Make the ID helper callable from within the templates.
    environment.globals["kid_from_name"] = kid_from_name

    if args.language == "cpp":
        template = environment.get_template("cpp_constants.jinja")
        out = template.render(
            {
                "spec": spec,
                "package_name": "::".join(args.package),
            }
        )
    else:
        error = f"Invalid language selected: '{args.language}'"
        raise ValueError(error)

    sys.stdout.write(out)
278
279
def validate_package_arg(value: str) -> str:
    """
    Validate that the package argument is a valid string

    A valid package is either empty or a dot separated sequence of
    identifiers. Each identifier starts with a letter, '_' or '$' and
    continues with word characters or '$'.

    Args:
      value: The package name

    Returns:
      The same value after being validated.

    Raises:
      argparse.ArgumentError: If the value is not a valid package name.

    """
    if value is None or value == "":
        return value
    # fullmatch() is required here: match() only anchors at the start of the
    # string, so a value with a valid prefix followed by invalid characters
    # (e.g. "com.google!") would otherwise be accepted.
    if not re.fullmatch(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*", value):
        raise argparse.ArgumentError(
            argument=None,
            message=f"Invalid string {value}. Must use alphanumeric values "
            "separated by dots.",
        )
    return value
300
301
def get_args() -> Args:
    """
    Parse the command line into a typed `Args` instance

    Recognised flags are `--package` / `-pkg` (validated, split on '.'),
    `--language` (only "cpp" is accepted) and `--zephyr` (boolean switch).

    Returns:
      Typed arguments class instance

    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--package",
        "-pkg",
        default="",
        type=validate_package_arg,
        help="Output package name separated by '.', example: com.google",
    )
    cli.add_argument("--language", type=str, choices=["cpp"], default="cpp")
    cli.add_argument("--zephyr", action="store_true")

    parsed = cli.parse_args()
    # The dotted package name is handed on as its individual components.
    package_parts = parsed.package.split(".")
    return Args(
        package=package_parts,
        language=parsed.language,
        zephyr=parsed.zephyr,
    )
334
335
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
338