1# Copyright 2024 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Tooling to generate C++ constants from a yaml sensor definition.""" 15 16import argparse 17from dataclasses import dataclass 18from collections.abc import Sequence 19import io 20import re 21import sys 22 23import yaml 24 25 26@dataclass(frozen=True) 27class Printable: 28 """Common printable object""" 29 30 id: str 31 name: str 32 description: str | None 33 34 @property 35 def variable_name(self) -> str: 36 return "k" + ''.join( 37 ele.title() for ele in re.split(r"[\s_-]+", self.id) 38 ) 39 40 def print(self, writer: io.TextIOWrapper) -> None: 41 writer.write( 42 f""" 43/// @var k{self.variable_name} 44/// @brief {self.name} 45""" 46 ) 47 if self.description: 48 writer.write( 49 f"""/// 50/// {self.description} 51""" 52 ) 53 54 55@dataclass(frozen=True) 56class Units: 57 """A single unit representation""" 58 59 name: str 60 symbol: str 61 62 63@dataclass(frozen=True) 64class Attribute(Printable): 65 """A single attribute representation.""" 66 67 units: Units 68 69 def print(self, writer: io.TextIOWrapper) -> None: 70 """Print header definition for this attribute""" 71 super().print(writer=writer) 72 writer.write( 73 f""" 74PW_SENSOR_ATTRIBUTE_TYPE( 75 static, 76 {super().variable_name}, 77 "PW_SENSOR_ATTRIBUTE_TYPE", 78 "{self.name}", 79 "{self.units.symbol}" 80); 81""" 82 ) 83 84 85@dataclass(frozen=True) 86class Channel(Printable): 87 """A single channel representation.""" 88 89 units: Units 90 91 def print(self, writer: io.TextIOWrapper) -> None: 92 """Print header definition for this channel""" 93 super().print(writer=writer) 94 writer.write( 95 f""" 96PW_SENSOR_MEASUREMENT_TYPE( 97 static, 98 {super().variable_name}, 99 "PW_SENSOR_MEASUREMENT_TYPE", 100 "{self.name}", 101 "{self.units.symbol}" 102); 103""" 104 ) 105 106 107@dataclass(frozen=True) 108class Trigger(Printable): 109 """A single trigger representation.""" 110 111 id: str 112 name: str 113 description: str 114 115 def print(self, writer: io.TextIOWrapper) -> None: 116 """Print header definition for this trigger""" 117 super().print(writer=writer) 118 writer.write( 119 f""" 120PW_SENSOR_TRIGGER_TYPE( 121 static, 122 {super().variable_name}, 123 "PW_SENSOR_TRIGGER_TYPE", 124 "{self.name}" 125); 126""" 127 ) 128 129 130@dataclass 131class Args: 132 """CLI arguments""" 133 134 package: Sequence[str] 135 language: str 136 137 138def attribute_from_dict(attribute_id: str, definition: dict) -> Attribute: 139 """Construct an Attribute from a dictionary entry.""" 140 return Attribute( 141 id=attribute_id, 142 name=definition["name"], 143 description=definition["description"], 144 units=Units( 145 name=definition["units"]["name"], 146 symbol=definition["units"]["symbol"], 147 ), 148 ) 149 150 151def channel_from_dict(channel_id: str, definition: dict) -> Channel: 152 """Construct a Channel from a dictionary entry.""" 153 return Channel( 154 id=channel_id, 155 name=definition["name"], 156 description=definition["description"], 157 units=Units( 158 name=definition["units"]["name"], 159 symbol=definition["units"]["symbol"], 160 ), 161 ) 162 163 164def trigger_from_dict(trigger_id: str, definition: dict) -> Trigger: 165 """Construct a Trigger from a dictionary entry.""" 166 return Trigger( 167 id=trigger_id, 168 name=definition["name"], 169 description=definition["description"], 170 ) 171 172 173class CppHeader: 174 """Generator for a C++ header""" 175 176 def __init__( 177 self, 178 package: Sequence[str], 179 attributes: Sequence[Attribute], 180 channels: Sequence[Channel], 181 triggers: Sequence[Trigger], 182 ) -> None: 183 """ 184 Args: 185 package: The package name used in the output. In C++ we'll convert 186 this to a namespace. 187 units: A sequence of units which should be exposed as 188 ::pw::sensor::MeasurementType. 189 """ 190 self._package: str = '::'.join(package) 191 self._attributes: Sequence[Attribute] = attributes 192 self._channels: Sequence[Channel] = channels 193 self._triggers: Sequence[Trigger] = triggers 194 195 def __str__(self) -> str: 196 writer = io.StringIO() 197 self._print_header(writer=writer) 198 self._print_constants(writer=writer) 199 self._print_footer(writer=writer) 200 return writer.getvalue() 201 202 def _print_header(self, writer: io.TextIOWrapper) -> None: 203 """ 204 Print the top part of the .h file (pragma, includes, and namespace) 205 206 Args: 207 writer: Where to write the text to 208 """ 209 writer.write( 210 "/* Auto-generated file, do not edit */\n" 211 "#pragma once\n" 212 "\n" 213 "#include \"pw_sensor/types.h\"\n" 214 ) 215 if self._package: 216 writer.write(f"namespace {self._package} {{\n\n") 217 218 def _print_constants(self, writer: io.TextIOWrapper) -> None: 219 """ 220 Print the constants definitions from self._attributes, self._channels, 221 and self._trigger 222 223 Args: 224 writer: Where to write the text 225 """ 226 227 writer.write("namespace attributes {\n") 228 for attribute in self._attributes: 229 attribute.print(writer) 230 writer.write("} // namespace attributes\n") 231 writer.write("namespace channels {\n") 232 for channel in self._channels: 233 channel.print(writer) 234 writer.write("} // namespace channels\n") 235 writer.write("namespace triggers {\n") 236 for trigger in self._triggers: 237 trigger.print(writer) 238 writer.write("} // namespace triggers\n") 239 240 def _print_footer(self, writer: io.TextIOWrapper) -> None: 241 """ 242 Write the bottom part of the .h file (closing namespace) 243 244 Args: 245 writer: Where to write the text 246 """ 247 if self._package: 248 writer.write(f"\n}} // namespace {self._package}") 249 250 251def main() -> None: 252 """ 253 Main entry point, this function will: 254 - Get CLI flags 255 - Read YAML from stdin 256 - Find all channel definitions 257 - Print header 258 """ 259 args = get_args() 260 spec = yaml.safe_load(sys.stdin) 261 all_attributes: set[Attribute] = set() 262 all_channels: set[Channel] = set() 263 all_triggers: set[Trigger] = set() 264 for attribute_id, definition in spec["attributes"].items(): 265 attribute = attribute_from_dict( 266 attribute_id=attribute_id, definition=definition 267 ) 268 assert not attribute in all_attributes 269 all_attributes.add(attribute) 270 for channel_id, definition in spec["channels"].items(): 271 channel = channel_from_dict( 272 channel_id=channel_id, definition=definition 273 ) 274 assert not channel in all_channels 275 all_channels.add(channel) 276 for trigger_id, definition in spec["triggers"].items(): 277 trigger = trigger_from_dict( 278 trigger_id=trigger_id, definition=definition 279 ) 280 assert not trigger in all_triggers 281 all_triggers.add(trigger) 282 283 if args.language == "cpp": 284 out = CppHeader( 285 package=args.package, 286 attributes=list(all_attributes), 287 channels=list(all_channels), 288 triggers=list(all_triggers), 289 ) 290 else: 291 raise ValueError(f"Invalid language selected: '{args.language}'") 292 print(out) 293 294 295def validate_package_arg(value: str) -> str: 296 """ 297 Validate that the package argument is a valid string 298 """ 299 if value is None or value == "": 300 return value 301 if not re.match(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*", value): 302 raise argparse.ArgumentError( 303 argument=None, 304 message=f"Invalid string {value}. Must use alphanumeric values " 305 "separated by dots.", 306 ) 307 return value 308 309 310def get_args() -> Args: 311 """ 312 Get CLI arguments 313 314 Returns: 315 Typed arguments class instance 316 """ 317 parser = argparse.ArgumentParser() 318 parser.add_argument( 319 "--package", 320 "-pkg", 321 default="", 322 type=validate_package_arg, 323 help="Output package name separated by '.', example: com.google", 324 ) 325 parser.add_argument( 326 "--language", 327 type=str, 328 choices=["cpp"], 329 default="cpp", 330 ) 331 args = parser.parse_args() 332 return Args(package=args.package.split("."), language=args.language) 333 334 335if __name__ == "__main__": 336 main() 337