# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tooling to generate C++ constants from a yaml sensor definition."""

import argparse
import importlib.resources
import re
import sys
import types
import typing
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

import jinja2
import yaml

# Word separators used by kid_from_name(): whitespace, '_', '-' and ','.
_NAME_SEPARATORS = re.compile(r"[\s_\-\,]+")

# One or more identifiers joined by dots, e.g. "com.google". It is applied
# with fullmatch() so trailing invalid characters are rejected as well.
_PACKAGE_PATTERN = re.compile(r"[a-zA-Z_$][\w$]*(\.[a-zA-Z_$][\w$]*)*")


def kid_from_name(name: str) -> str:
    """Generate a const style ID name from a given name string.

    Example:
        If name is "sample_rate", the ID would be kSampleRate

    Each word is passed through ``str.title()``, so any upper-case letters
    inside a word are lowered (e.g. "sampleRate" becomes "kSamplerate").

    Args:
        name: the name to convert to an ID

    Returns:
        C++ style 'k' prefixed camel cased ID

    """
    return "k" + "".join(word.title() for word in _NAME_SEPARATORS.split(name))


@dataclass
class UnitsSpec:
    """Typing for the Units definition dictionary."""

    name: str
    symbol: str


@dataclass
class AttributeSpec:
    """Typing for the Attribute definition dictionary."""

    name: str
    description: str


@dataclass
class ChannelSpec:
    """Typing for the Channel definition dictionary."""

    name: str
    description: str
    units: str


@dataclass
class TriggerSpec:
    """Typing for the Trigger definition dictionary."""

    name: str
    description: str


@dataclass
class SensorAttributeSpec:
    """Typing for the SensorAttribute definition dictionary."""

    # Optional: set to None when the key is absent from the input.
    channel: str | None
    # Optional: set to None when the key is absent from the input.
    trigger: str | None
    attribute: str
    units: str


@dataclass
class CompatibleSpec:
    """Typing for the Compatible dictionary."""

    # Optional: set to None when the key is absent from the input.
    org: str | None
    part: str


@dataclass
class SensorSpec:
    """Typing for the Sensor definition dictionary."""

    description: str
    compatible: CompatibleSpec
    supported_buses: list[str]
    attributes: list[SensorAttributeSpec]
    channels: dict[str, list[ChannelSpec]]
    # Raw parsed values; no dataclass conversion is applied to the items.
    triggers: list[Any]
    # Raw parsed values; no dataclass conversion is applied to the values.
    extras: dict[str, Any]


@dataclass
class Args:
    """CLI arguments"""

    package: Sequence[str]
    language: str
    zephyr: bool


@dataclass
class InputSpec:
    """Typing for the InputData spec dictionary"""

    units: dict[str, UnitsSpec]
    attributes: dict[str, AttributeSpec]
    channels: dict[str, ChannelSpec]
    triggers: dict[str, TriggerSpec]
    sensors: dict[str, SensorSpec]


def is_list_type(t: Any) -> bool:  # noqa: ANN401
    """
    Checks if the given type `t` is a list.

    Both the bare ``list`` annotation and parameterized forms such as
    ``list[int]`` are recognized.

    Args:
        t: The type to check.

    Returns:
        True if `t` is a list type, False otherwise.

    """
    return t is list or typing.get_origin(t) is list


def is_dict_type(t: Any) -> bool:  # noqa: ANN401
    """
    Checks if the given type `t` is a dict (or dict subclass) type.

    Both the bare ``dict`` annotation and parameterized forms such as
    ``dict[str, int]`` are recognized.

    Args:
        t: The type to check.

    Returns:
        True if `t` is a dict type, False otherwise.

    """
    origin = typing.get_origin(t) or t
    return isinstance(origin, type) and issubclass(origin, dict)


def is_primitive(value: Any) -> bool:  # noqa: ANN401
    """Checks if the given value is of a primitive type.

    Note that ``None`` is not considered primitive.

    Args:
        value: The value to check.

    Returns:
        True if the value is of a primitive type, False otherwise.

    """
    return isinstance(value, int | float | complex | str | bool)


def is_union(t: Any) -> bool:  # noqa: ANN401
    """Check if the given type is a union

    Both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions are recognized.

    Args:
        t: The type to check.

    Returns:
        True if `t` is a union type, False otherwise.

    """
    return (
        typing.get_origin(t) is typing.Union
        or typing.get_origin(t) is types.UnionType
    )


def create_dataclass_from_dict(
    cls: Any,  # noqa: ANN401
    data: Any,  # noqa: ANN401
    indent: int = 0,
) -> Any:  # noqa: ANN401
    """Recursively creates a dataclass instance from a nested dictionary.

    Args:
        cls: The target type. ``list[T]`` and ``dict[K, V]`` types have their
            elements converted to ``T`` / ``V``. Dataclass types are built
            field by field. Any other type (e.g. ``str``, ``Any`` or a union)
            returns ``data`` unchanged.
        data: The parsed value (e.g. from YAML) to convert.
        indent: Recursion depth bookkeeping. It is only forwarded to
            recursive calls and does not affect the result.

    Returns:
        The converted value: an instance of ``cls`` for dataclasses, a new
        list/dict for list/dict types, otherwise ``data`` itself.

    Raises:
        ValueError: If ``cls`` is a dataclass and ``data`` is not a mapping,
            or if a non-optional field is missing or null.

    """
    if is_list_type(cls):
        item_types = typing.get_args(cls)
        if not item_types:
            # A bare ``list`` annotation has no item type to convert to.
            return data
        return [
            create_dataclass_from_dict(item_types[0], item, indent + 2)
            for item in data
        ]

    if is_primitive(data):
        return data

    if is_dict_type(cls):
        # We might not have types specified in the dataclass
        key_value_types = typing.get_args(cls)
        if len(key_value_types) != 2:  # noqa: PLR2004
            return data
        value_type = key_value_types[1]
        return {
            key: create_dataclass_from_dict(value_type, val, indent + 2)
            for key, val in data.items()
        }

    if not is_dataclass(cls):
        # e.g. ``Any`` or a union: nothing to convert, keep the raw value.
        # (Previously this fell through to fields(cls) and raised TypeError.)
        return data

    if not isinstance(data, Mapping):
        error = (
            f"Expected a mapping to build {cls.__name__}, "
            f"got {type(data).__name__}"
        )
        raise ValueError(error)

    field_values: dict[str, Any] = {}
    for field in fields(cls):
        field_value = data.get(field.name)
        if field_value is None:
            # Input keys may be spelled with '-' where the field uses '_'.
            field_value = data.get(field.name.replace("_", "-"))

        if field_value is None:
            if is_union(field.type) and type(None) in typing.get_args(
                field.type
            ):
                # We have an optional field and no value, skip it
                field_values[field.name] = None
                continue
            # Raise explicitly (not assert) so validation survives `python -O`.
            error = f"Missing required field '{field.name}' for {cls.__name__}"
            raise ValueError(error)

        # Lists, dicts and nested dataclasses are converted recursively;
        # anything else is stored unchanged.
        field_values[field.name] = create_dataclass_from_dict(
            field.type, field_value, indent + 2
        )

    return cls(**field_values)


def main() -> None:
    """Main entry point

    This function will:
    - Get CLI flags
    - Read YAML from stdin
    - Find all attribute, channel, trigger, and unit definitions
    - Print header
    """
    args = get_args()
    yaml_input = yaml.safe_load(sys.stdin)
    spec: InputSpec = create_dataclass_from_dict(InputSpec, yaml_input)

    # importlib.resources.contents() is deprecated since Python 3.11;
    # files().iterdir() is the documented replacement.
    templates_dir = importlib.resources.files("pw_sensor.templates")
    jinja_templates = {
        entry.name: entry.read_text(encoding="utf-8")
        for entry in templates_dir.iterdir()
        if entry.name.endswith(".jinja")
    }
    environment = jinja2.Environment(
        loader=jinja2.DictLoader(jinja_templates),
        # NOTE(review): HTML autoescaping also applies to the generated C++
        # (e.g. '&' in a description renders as '&amp;'); confirm the
        # templates expect this.
        autoescape=True,
        # Trim whitespace in templates
        trim_blocks=True,
        lstrip_blocks=True,
    )
    environment.globals["kid_from_name"] = kid_from_name

    if args.language == "cpp":
        template = environment.get_template("cpp_constants.jinja")
        out = template.render(
            {
                "spec": spec,
                "package_name": "::".join(args.package),
            }
        )
    else:
        error = f"Invalid language selected: '{args.language}'"
        raise ValueError(error)

    sys.stdout.write(out)


def validate_package_arg(value: str) -> str:
    """
    Validate that the package argument is a valid string

    An empty value is accepted as-is (no package). Otherwise the whole value
    must be identifiers separated by dots, e.g. "com.google".

    Args:
        value: The package name

    Returns:
        The same value after being validated.

    Raises:
        argparse.ArgumentTypeError: If the value is not a valid package name.
            argparse reports this as a usage error for the argument.

    """
    if not value:
        return value
    # fullmatch (not match) so that trailing invalid characters are rejected.
    if _PACKAGE_PATTERN.fullmatch(value) is None:
        message = (
            f"Invalid string {value}. Must use alphanumeric values "
            "separated by dots."
        )
        raise argparse.ArgumentTypeError(message)
    return value


def get_args() -> Args:
    """
    Get CLI arguments

    Returns:
        Typed arguments class instance

    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--package",
        "-pkg",
        default="",
        type=validate_package_arg,
        help="Output package name separated by '.', example: com.google",
    )
    parser.add_argument(
        "--language",
        type=str,
        choices=["cpp"],
        default="cpp",
    )
    parser.add_argument(
        "--zephyr",
        action="store_true",
    )
    args = parser.parse_args()
    return Args(
        package=args.package.split("."),
        language=args.language,
        zephyr=args.zephyr,
    )


if __name__ == "__main__":
    main()