• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14"""Tooling to generate C++ constants from a yaml sensor definition."""
15
16import argparse
17from dataclasses import dataclass
18from collections.abc import Sequence
19import io
20import re
21import sys
22
23import yaml
24
25
26@dataclass(frozen=True)
27class Printable:
28    """Common printable object"""
29
30    id: str
31    name: str
32    description: str | None
33
34    @property
35    def variable_name(self) -> str:
36        return "k" + ''.join(
37            ele.title() for ele in re.split(r"[\s_-]+", self.id)
38        )
39
40    def print(self, writer: io.TextIOWrapper) -> None:
41        writer.write(
42            f"""
43/// @var k{self.variable_name}
44/// @brief {self.name}
45"""
46        )
47        if self.description:
48            writer.write(
49                f"""///
50/// {self.description}
51"""
52            )
53
54
@dataclass(frozen=True)
class Units:
    """Immutable pair describing the unit of a measured or configured value."""

    # Name of the unit as given in the definition.
    name: str
    # Symbol of the unit; this is the text written into the generated macros.
    symbol: str
61
62
@dataclass(frozen=True)
class Attribute(Printable):
    """Sensor attribute definition that can print its own C++ constant."""

    # Units of the attribute; only the symbol is written to the header.
    units: Units

    def print(self, writer: io.TextIOWrapper) -> None:
        """Write the comment block followed by the attribute type macro.

        Args:
          writer: Where to write the text to
        """
        # The shared Doxygen comment comes first, then the macro invocation.
        super().print(writer=writer)
        constant = super().variable_name
        macro_lines = (
            "",
            "PW_SENSOR_ATTRIBUTE_TYPE(",
            "    static,",
            f"    {constant},",
            '    "PW_SENSOR_ATTRIBUTE_TYPE",',
            f'    "{self.name}",',
            f'    "{self.units.symbol}"',
            ");",
            "",
        )
        writer.write("\n".join(macro_lines))
83
84
@dataclass(frozen=True)
class Channel(Printable):
    """Sensor channel definition that can print its own C++ constant."""

    # Units of the channel; only the symbol is written to the header.
    units: Units

    def print(self, writer: io.TextIOWrapper) -> None:
        """Write the comment block followed by the measurement type macro.

        Args:
          writer: Where to write the text to
        """
        # The shared Doxygen comment comes first, then the macro invocation.
        super().print(writer=writer)
        constant = super().variable_name
        macro_lines = (
            "",
            "PW_SENSOR_MEASUREMENT_TYPE(",
            "    static,",
            f"    {constant},",
            '    "PW_SENSOR_MEASUREMENT_TYPE",',
            f'    "{self.name}",',
            f'    "{self.units.symbol}"',
            ");",
            "",
        )
        writer.write("\n".join(macro_lines))
105
106
@dataclass(frozen=True)
class Trigger(Printable):
    """Sensor trigger definition that can print its own C++ constant."""

    # The base class fields are declared again here; the description is
    # annotated as a plain str instead of an optional one.
    id: str
    name: str
    description: str

    def print(self, writer: io.TextIOWrapper) -> None:
        """Write the comment block followed by the trigger type macro.

        Args:
          writer: Where to write the text to
        """
        # The shared Doxygen comment comes first, then the macro invocation.
        super().print(writer=writer)
        constant = super().variable_name
        macro_lines = (
            "",
            "PW_SENSOR_TRIGGER_TYPE(",
            "    static,",
            f"    {constant},",
            '    "PW_SENSOR_TRIGGER_TYPE",',
            f'    "{self.name}"',
            ");",
            "",
        )
        writer.write("\n".join(macro_lines))
128
129
@dataclass
class Args:
    """Typed view of the parsed command line flags."""

    # Components of the output package, one entry per dot separated part.
    package: Sequence[str]
    # Name of the output language to generate.
    language: str
136
137
def attribute_from_dict(attribute_id: str, definition: dict) -> Attribute:
    """Build an Attribute from its id and its dictionary definition.

    Args:
      attribute_id: Identifier of the attribute
      definition: Mapping with the "name", "description" and "units" entries,
        where "units" itself holds "name" and "symbol"

    Returns:
      The Attribute described by the definition
    """
    name = definition["name"]
    description = definition["description"]
    units_definition = definition["units"]
    units = Units(
        name=units_definition["name"],
        symbol=units_definition["symbol"],
    )
    return Attribute(
        id=attribute_id,
        name=name,
        description=description,
        units=units,
    )
149
150
def channel_from_dict(channel_id: str, definition: dict) -> Channel:
    """Build a Channel from its id and its dictionary definition.

    Args:
      channel_id: Identifier of the channel
      definition: Mapping with the "name", "description" and "units" entries,
        where "units" itself holds "name" and "symbol"

    Returns:
      The Channel described by the definition
    """
    name = definition["name"]
    description = definition["description"]
    units_definition = definition["units"]
    units = Units(
        name=units_definition["name"],
        symbol=units_definition["symbol"],
    )
    return Channel(
        id=channel_id,
        name=name,
        description=description,
        units=units,
    )
162
163
def trigger_from_dict(trigger_id: str, definition: dict) -> Trigger:
    """Build a Trigger from its id and its dictionary definition.

    Args:
      trigger_id: Identifier of the trigger
      definition: Mapping with the "name" and "description" entries

    Returns:
      The Trigger described by the definition
    """
    name = definition["name"]
    description = definition["description"]
    return Trigger(id=trigger_id, name=name, description=description)
171
172
class CppHeader:
    """Renders attribute, channel and trigger constants as a C++ header."""

    def __init__(
        self,
        package: Sequence[str],
        attributes: Sequence[Attribute],
        channels: Sequence[Channel],
        triggers: Sequence[Trigger],
    ) -> None:
        """
        Args:
          package: The package name components used in the output. They are
            joined with '::' to form the enclosing C++ namespace.
          attributes: The attributes printed inside `namespace attributes`.
          channels: The channels printed inside `namespace channels`.
          triggers: The triggers printed inside `namespace triggers`.
        """
        self._package: str = '::'.join(package)
        self._attributes: Sequence[Attribute] = attributes
        self._channels: Sequence[Channel] = channels
        self._triggers: Sequence[Trigger] = triggers

    def __str__(self) -> str:
        """Return the full header text: prologue, constants, then footer."""
        buffer = io.StringIO()
        for emit in (
            self._print_header,
            self._print_constants,
            self._print_footer,
        ):
            emit(writer=buffer)
        return buffer.getvalue()

    def _print_header(self, writer: io.TextIOWrapper) -> None:
        """
        Print the top part of the .h file (pragma, include, and the opening
        of the package namespace when a package is set)

        Args:
          writer: Where to write the text to
        """
        prologue = "\n".join(
            (
                "/* Auto-generated file, do not edit */",
                "#pragma once",
                "",
                '#include "pw_sensor/types.h"',
                "",
            )
        )
        writer.write(prologue)
        if not self._package:
            return
        writer.write(f"namespace {self._package} {{\n\n")

    def _print_constants(self, writer: io.TextIOWrapper) -> None:
        """
        Print the constant definitions from self._attributes, self._channels,
        and self._triggers, each group wrapped in its own namespace

        Args:
            writer: Where to write the text
        """
        sections = (
            ("attributes", self._attributes),
            ("channels", self._channels),
            ("triggers", self._triggers),
        )
        for namespace, definitions in sections:
            writer.write(f"namespace {namespace} {{\n")
            for definition in definitions:
                definition.print(writer)
            writer.write(f"}}  // namespace {namespace}\n")

    def _print_footer(self, writer: io.TextIOWrapper) -> None:
        """
        Write the bottom part of the .h file (closing the package namespace
        when a package is set)

        Args:
            writer: Where to write the text
        """
        if not self._package:
            return
        writer.write(f"\n}}  // namespace {self._package}")
249
250
def main() -> None:
    """
    Main entry point, this function will:
    - Get CLI flags
    - Read YAML from stdin
    - Find all attribute, channel, and trigger definitions
    - Print header

    Definitions are emitted sorted by their id so that the same input always
    produces the same output.

    Raises:
      ValueError: If the selected language is not supported.
    """
    args = get_args()
    # safe_load returns None for an empty document; treat it as an empty spec.
    spec = yaml.safe_load(sys.stdin) or {}

    def load_section(section, factory):
        """Build the definitions of one top-level section, ordered by id."""
        # A missing or empty section simply contributes no definitions.
        definitions = spec.get(section) or {}
        # Mapping keys are unique, so no duplicate check is needed. Sorting
        # replaces the former set based collection whose iteration order
        # could change from run to run.
        return [factory(key, definitions[key]) for key in sorted(definitions)]

    all_attributes = load_section("attributes", attribute_from_dict)
    all_channels = load_section("channels", channel_from_dict)
    all_triggers = load_section("triggers", trigger_from_dict)

    if args.language != "cpp":
        raise ValueError(f"Invalid language selected: '{args.language}'")
    out = CppHeader(
        package=args.package,
        attributes=all_attributes,
        channels=all_channels,
        triggers=all_triggers,
    )
    print(out)
293
294
def validate_package_arg(value: str) -> str:
    """
    Validate that the package argument is a valid string

    An empty value is accepted unchanged. Any other value must consist, in
    its entirety, of identifiers separated by single dots.

    Args:
      value: The raw package string from the command line

    Returns:
      The unchanged value when it is valid

    Raises:
      argparse.ArgumentError: If the value is not a dot separated list of
        identifiers.
    """
    if not value:
        return value
    # fullmatch is required: match() only anchors the start and would accept
    # any string that merely begins with a valid package name.
    if not re.fullmatch(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*", value):
        raise argparse.ArgumentError(
            argument=None,
            message=f"Invalid string {value}. Must use alphanumeric values "
            "separated by dots.",
        )
    return value
308
309
def get_args() -> Args:
    """
    Parse the command line flags

    Returns:
      Typed arguments class instance; the package string is split on '.'
      into its components
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--package",
        "-pkg",
        type=validate_package_arg,
        default="",
        help="Output package name separated by '.', example: com.google",
    )
    cli.add_argument(
        "--language",
        choices=["cpp"],
        default="cpp",
        type=str,
    )
    parsed = cli.parse_args()
    package_parts = parsed.package.split(".")
    return Args(package=package_parts, language=parsed.language)
333
334
# Only run the generator when this file is executed as a script.
if __name__ == "__main__":
    main()
337