#!/usr/bin/python

'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
nanopb_version = "nanopb-0.2.8-dev"

import sys

try:
    # Add some dummy imports to keep packaging tools happy.
    import google, distutils.util # bbfreeze seems to need these
    import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
except Exception:
    # Don't care, we will error out later if it is actually important.
    # (Exception rather than a bare except, so Ctrl-C / SystemExit still work.)
    pass

try:
    import google.protobuf.text_format as text_format
    import google.protobuf.descriptor_pb2 as descriptor
except:
    sys.stderr.write('''
         *************************************************************
         *** Could not import the Google protobuf Python libraries ***
         *** Try installing package 'python-protobuf' or similar.  ***
         *************************************************************
    ''' + '\n')
    raise

try:
    import proto.nanopb_pb2 as nanopb_pb2
    import proto.plugin_pb2 as plugin_pb2
except:
    sys.stderr.write('''
         ********************************************************************
         *** Failed to import the protocol definitions for generator.     ***
         *** You have to run 'make' in the nanopb/generator/proto folder. ***
         ********************************************************************
    ''' + '\n')
    raise

# ---------------------------------------------------------------------------
#                     Generation of single fields
# ---------------------------------------------------------------------------

import time
import os.path

# Values are tuple (c type, pb type, encoded size)
FieldD = descriptor.FieldDescriptorProto
datatypes = {
    FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1),
    FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8),
    FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4),
    FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8),
    FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4),
    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10),
    FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10),
    FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4),
    FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8),
    FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5),
    FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10),
    FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5),
    FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10)
}

class Names:
    '''Keeps a set of nested names and formats them to C identifier.

    The parts are stored as a tuple and joined with '_' by str().
    Adding a string or tuple returns a new Names; instances are never mutated.
    '''
    def __init__(self, parts = ()):
        if isinstance(parts, Names):
            parts = parts.parts
        self.parts = tuple(parts)

    def __str__(self):
        return '_'.join(self.parts)

    def __add__(self, other):
        if isinstance(other, (str, unicode)):
            return Names(self.parts + (other,))
        elif isinstance(other, tuple):
            return Names(self.parts + other)
        else:
            raise ValueError("Name parts should be of type str")

    def __eq__(self, other):
        return isinstance(other, Names) and self.parts == other.parts

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # Names would compare as "not equal" by identity.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__ (equal parts -> equal hash).
        return hash(self.parts)

def names_from_type_name(type_name):
    '''Parse Names() from FieldDescriptorProto type_name.

    Only absolute names (leading '.') are supported; e.g. '.pkg.Msg'
    becomes Names(('pkg', 'Msg')).
    '''
    if type_name[0] != '.':
        raise NotImplementedError("Lookup of non-absolute type names is not supported")
    return Names(type_name[1:].split('.'))

def varint_max_size(max_value):
    '''Returns the maximum number of bytes a varint can take when encoded.

    Each varint byte carries 7 bits of payload; values needing more than
    10 bytes (i.e. more than 70 bits) raise ValueError.
    '''
    for i in range(1, 11):
        if (max_value >> (i * 7)) == 0:
            return i
    raise ValueError("Value too large for varint: " + str(max_value))

assert varint_max_size(0) == 1
assert varint_max_size(127) == 1
assert varint_max_size(128) == 2

class EncodedSize:
    '''Class used to represent the encoded size of a field or a message.
    Consists of a combination of symbolic sizes and integer sizes.

    value:   integer part of the size.
    symbols: list of C expressions (strings) that are added to value when
             the size is rendered with str().
    '''
    def __init__(self, value = 0, symbols = None):
        if isinstance(value, (str, Names)):
            symbols = [str(value)]
            value = 0
        self.value = value
        # Copy so that instances never share (or alias a default) list.
        self.symbols = list(symbols) if symbols is not None else []

    def __add__(self, other):
        if isinstance(other, (int, long)):
            return EncodedSize(self.value + other, self.symbols)
        elif isinstance(other, (str, Names)):
            return EncodedSize(self.value, self.symbols + [str(other)])
        elif isinstance(other, EncodedSize):
            return EncodedSize(self.value + other.value, self.symbols + other.symbols)
        else:
            raise ValueError("Cannot add size: " + repr(other))

    def __mul__(self, other):
        if isinstance(other, (int, long)):
            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
        else:
            raise ValueError("Cannot multiply size: " + repr(other))

    def __str__(self):
        if not self.symbols:
            return str(self.value)
        else:
            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'

    def upperlimit(self):
        '''Numeric upper bound: the exact value if purely numeric,
        otherwise the conservative 2**32 - 1.'''
        if not self.symbols:
            return self.value
        else:
            return 2**32 - 1

class Enum:
    '''A protobuf enum type, rendered as a C "typedef enum".'''
    def __init__(self, names, desc, enum_options):
        '''desc is EnumDescriptorProto

        names:        Names of the scope containing the enum.
        enum_options: NanoPBOptions; long_names selects whether the value
                      identifiers are prefixed with the enum's own name.
        '''

        self.options = enum_options
        self.names = names + desc.name

        if enum_options.long_names:
            self.values = [(self.names + x.name, x.number) for x in desc.value]
        else:
            self.values = [(names + x.name, x.number) for x in desc.value]

        # Always the long form; used by parse_file() to map field default
        # values to short names when long_names is disabled.
        self.value_longnames = [self.names + x.name for x in desc.value]

    def __str__(self):
        result = 'typedef enum _%s {\n' % self.names
        result += ',\n'.join(["    %s = %d" % x for x in self.values])
        result += '\n} %s;' % self.names
        return result

class Field:
    '''A single message field and the C code fragments generated for it.'''
    def __init__(self, struct_name, desc, field_options):
        '''desc is FieldDescriptorProto

        struct_name:   Names of the containing message struct.
        field_options: NanoPBOptions for this field. Note that its "type"
                       member is updated in place when it is FT_DEFAULT.
        '''
        self.tag = desc.number
        self.struct_name = struct_name
        self.name = desc.name
        self.default = None
        self.max_size = None
        self.max_count = None
        self.array_decl = ""
        self.enc_size = None
        self.ctype = None

        # Parse field options
        if field_options.HasField("max_size"):
            self.max_size = field_options.max_size

        if field_options.HasField("max_count"):
            self.max_count = field_options.max_count

        if desc.HasField('default_value'):
            self.default = desc.default_value

        # Check field rules, i.e. required/optional/repeated.
        can_be_static = True
        if desc.label == FieldD.LABEL_REQUIRED:
            self.rules = 'REQUIRED'
        elif desc.label == FieldD.LABEL_OPTIONAL:
            self.rules = 'OPTIONAL'
        elif desc.label == FieldD.LABEL_REPEATED:
            self.rules = 'REPEATED'
            if self.max_count is None:
                can_be_static = False
            else:
                self.array_decl = '[%d]' % self.max_count
        else:
            raise NotImplementedError(desc.label)

        # Check if the field can be implemented with static allocation
        # i.e. whether the data size is known.
        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
            can_be_static = False

        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
            can_be_static = False

        # Decide how the field data will be allocated
        if field_options.type == nanopb_pb2.FT_DEFAULT:
            if can_be_static:
                field_options.type = nanopb_pb2.FT_STATIC
            else:
                field_options.type = nanopb_pb2.FT_CALLBACK

        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
            raise Exception("Field %s is defined as static, but max_size or "
                            "max_count is not given." % self.name)

        if field_options.type == nanopb_pb2.FT_STATIC:
            self.allocation = 'STATIC'
        elif field_options.type == nanopb_pb2.FT_POINTER:
            self.allocation = 'POINTER'
        elif field_options.type == nanopb_pb2.FT_CALLBACK:
            self.allocation = 'CALLBACK'
        else:
            raise NotImplementedError(field_options.type)

        # Decide the C data type to use in the struct.
        if desc.type in datatypes:
            self.ctype, self.pbtype, self.enc_size = datatypes[desc.type]
        elif desc.type == FieldD.TYPE_ENUM:
            self.pbtype = 'ENUM'
            self.ctype = names_from_type_name(desc.type_name)
            if self.default is not None:
                self.default = self.ctype + self.default
            self.enc_size = 5 # protoc rejects enum values > 32 bits
        elif desc.type == FieldD.TYPE_STRING:
            self.pbtype = 'STRING'
            self.ctype = 'char'
            if self.allocation == 'STATIC':
                self.array_decl += '[%d]' % self.max_size
                self.enc_size = varint_max_size(self.max_size) + self.max_size
        elif desc.type == FieldD.TYPE_BYTES:
            self.pbtype = 'BYTES'
            if self.allocation == 'STATIC':
                self.ctype = self.struct_name + self.name + 't'
                self.enc_size = varint_max_size(self.max_size) + self.max_size
            elif self.allocation == 'POINTER':
                self.ctype = 'pb_bytes_array_t'
        elif desc.type == FieldD.TYPE_MESSAGE:
            self.pbtype = 'MESSAGE'
            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
            self.enc_size = None # Needs to be filled in after the message type is available
        else:
            raise NotImplementedError(desc.type)

    def __cmp__(self, other):
        # Fields are ordered by tag number (used to build ordered_fields).
        return cmp(self.tag, other.tag)

    def __str__(self):
        '''Return the C struct member declaration(s) for this field.'''
        result = ''
        if self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                result += '    size_t ' + self.name + '_count;\n'

            if self.pbtype == 'MESSAGE':
                # Use struct definition, so recursive submessages are possible
                result += '    struct _%s *%s;' % (self.ctype, self.name)
            elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
                # String/bytes arrays need to be defined as pointers to pointers
                result += '    %s **%s;' % (self.ctype, self.name)
            else:
                result += '    %s *%s;' % (self.ctype, self.name)
        elif self.allocation == 'CALLBACK':
            result += '    pb_callback_t %s;' % self.name
        else:
            if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
                result += '    bool has_' + self.name + ';\n'
            elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
                result += '    size_t ' + self.name + '_count;\n'
            result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
        return result

    def types(self):
        '''Return definitions for any special types this field might need.'''
        if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
            result = 'typedef struct {\n'
            result += '    size_t size;\n'
            result += '    uint8_t bytes[%d];\n' % self.max_size
            result += '} %s;\n' % self.ctype
        else:
            result = None
        return result

    def default_decl(self, declaration_only = False):
        '''Return definition for this field's default value.

        declaration_only: if True, return an "extern" declaration for the
        header instead of the definition with its initializer.
        Returns None when the field has no default or defaults are not
        implemented for its allocation type.
        '''
        if self.default is None:
            return None

        ctype, default = self.ctype, self.default
        array_decl = ''

        if self.pbtype == 'STRING':
            if self.allocation != 'STATIC':
                return None # Not implemented

            array_decl = '[%d]' % self.max_size
            default = str(self.default).encode('string_escape')
            default = default.replace('"', '\\"')
            default = '"' + default + '"'
        elif self.pbtype == 'BYTES':
            if self.allocation != 'STATIC':
                return None # Not implemented

            data = self.default.decode('string_escape')
            data = ['0x%02x' % ord(c) for c in data]
            default = '{%d, {%s}}' % (len(data), ','.join(data))
        elif self.pbtype in ['FIXED32', 'UINT32']:
            default += 'u'
        elif self.pbtype in ['FIXED64', 'UINT64']:
            default += 'ull'
        elif self.pbtype in ['SFIXED64', 'INT64', 'SINT64']:
            # All three map to int64_t, so they get the same literal suffix.
            default += 'll'

        if declaration_only:
            return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
        else:
            return 'const %s %s_default%s = %s;' % (ctype, self.struct_name + self.name, array_decl, default)

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_%s_tag' % (self.struct_name, self.name)
        return '#define %-40s %d\n' % (identifier, self.tag)

    def pb_field_t(self, prev_field_name):
        '''Return the pb_field_t initializer to use in the constant array.
        prev_field_name is the name of the previous field or None.
        '''
        result  = '    PB_FIELD2(%3d, ' % self.tag
        result += '%-8s, ' % self.pbtype
        result += '%s, ' % self.rules
        result += '%-8s, ' % self.allocation
        result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
        result += '%s, ' % self.struct_name
        result += '%s, ' % self.name
        result += '%s, ' % (prev_field_name or self.name)

        if self.pbtype == 'MESSAGE':
            result += '&%s_fields)' % self.submsgname
        elif self.default is None:
            result += '0)'
        elif self.pbtype in ['BYTES', 'STRING'] and self.allocation != 'STATIC':
            result += '0)' # Arbitrary size default values not implemented
        elif self.rules == 'OPTEXT':
            result += '0)' # Default value for extensions is not implemented
        else:
            result += '&%s_default)' % (self.struct_name + self.name)

        return result

    def largest_field_value(self):
        '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
        Returns numeric value or a C-expression for assert.'''
        if self.pbtype == 'MESSAGE':
            if self.rules == 'REPEATED' and self.allocation == 'STATIC':
                return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name)
            else:
                return 'pb_membersize(%s, %s)' % (self.struct_name, self.name)

        return max(self.tag, self.max_size, self.max_count)

    def encoded_size(self, allmsgs):
        '''Return the maximum size that this field can take when encoded,
        including the field tag. If the size cannot be determined, returns
        None.

        allmsgs: list of Message objects used to resolve submessage sizes.
        '''

        if self.allocation != 'STATIC':
            return None

        if self.pbtype == 'MESSAGE':
            for msg in allmsgs:
                if msg.name == self.submsgname:
                    encsize = msg.encoded_size(allmsgs)
                    if encsize is None:
                        return None # Submessage size is indeterminate

                    # Include submessage length prefix
                    encsize += varint_max_size(encsize.upperlimit())
                    break
            else:
                # Submessage cannot be found, this currently occurs when
                # the submessage type is defined in a different file.
                # Instead of direct numeric value, reference the size that
                # has been #defined in the other file.
                encsize = EncodedSize(self.submsgname + 'size')

                # We will have to make a conservative assumption on the length
                # prefix size, though.
                encsize += 5

        elif self.enc_size is None:
            raise RuntimeError("Could not determine encoded size for %s.%s"
                               % (self.struct_name, self.name))
        else:
            encsize = EncodedSize(self.enc_size)

        encsize += varint_max_size(self.tag << 3) # Tag + wire type

        if self.rules == 'REPEATED':
            # Decoders must be always able to handle unpacked arrays.
            # Therefore we have to reserve space for it, even though
            # we emit packed arrays ourselves.
            encsize *= self.max_count

        return encsize


class ExtensionRange(Field):
    '''The pb_extension_t* member added to messages that declare extension ranges.'''
    def __init__(self, struct_name, range_start, field_options):
        '''Implements a special pb_extension_t* field in an extensible message
        structure. The range_start signifies the index at which the extensions
        start. Not necessarily all tags above this are extensions, it is merely
        a speed optimization.
        '''
        # Field.__init__ is intentionally not called: there is no
        # FieldDescriptorProto for this pseudo-field.
        self.tag = range_start
        self.struct_name = struct_name
        self.name = 'extensions'
        self.pbtype = 'EXTENSION'
        self.rules = 'OPTIONAL'
        self.allocation = 'CALLBACK'
        self.ctype = 'pb_extension_t'
        self.array_decl = ''
        self.default = None
        self.max_size = 0
        self.max_count = 0

    def __str__(self):
        return '    pb_extension_t *extensions;'

    def types(self):
        return None

    def tags(self):
        return ''

    def encoded_size(self, allmsgs):
        # We exclude extensions from the count, because they cannot be known
        # until runtime. Other option would be to return None here, but this
        # way the value remains useful if extensions are not used.
        return EncodedSize(0)

class ExtensionField(Field):
    '''An extension field declared with "extend"; only optional ones are generated.'''
    def __init__(self, struct_name, desc, field_options):
        self.fullname = struct_name + desc.name
        self.extendee_name = names_from_type_name(desc.extendee)
        Field.__init__(self, self.fullname + 'struct', desc, field_options)

        if self.rules != 'OPTIONAL':
            self.skip = True
        else:
            self.skip = False
            self.rules = 'OPTEXT'

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_tag' % self.fullname
        return '#define %-40s %d\n' % (identifier, self.tag)

    def extension_decl(self):
        '''Declaration of the extension type in the .pb.h file'''
        if self.skip:
            msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
            msg +='   type of extension fields is currently supported. */\n'
            return msg

        return 'extern const pb_extension_type_t %s;\n' % self.fullname

    def extension_def(self):
        '''Definition of the extension type in the .pb.c file'''

        if self.skip:
            return ''

        result  = 'typedef struct {\n'
        result += str(self)
        result += '\n} %s;\n\n' % self.struct_name
        result += ('static const pb_field_t %s_field = \n  %s;\n\n' %
                    (self.fullname, self.pb_field_t(None)))
        result += 'const pb_extension_type_t %s = {\n' % self.fullname
        result += '    NULL,\n'
        result += '    NULL,\n'
        result += '    &%s_field\n' % self.fullname
        result += '};\n'
        return result


# ---------------------------------------------------------------------------
#                   Generation of messages (structures)
# ---------------------------------------------------------------------------


class Message:
    '''A protobuf message type, rendered as a C struct plus pb_field_t array.'''
    def __init__(self, names, desc, message_options):
        '''names is the Names of the message; desc is DescriptorProto.'''
        self.name = names
        self.fields = []

        for f in desc.field:
            field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.fields.append(Field(self.name, f, field_options))

        if len(desc.extension_range) > 0:
            field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
            range_start = min([r.start for r in desc.extension_range])
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.fields.append(ExtensionRange(self.name, range_start, field_options))

        self.packed = message_options.packed_struct
        # Copy sorted by tag number (Field.__cmp__); self.fields keeps .proto order.
        self.ordered_fields = self.fields[:]
        self.ordered_fields.sort()

    def get_dependencies(self):
        '''Get list of type names that this structure refers to.'''
        return [str(field.ctype) for field in self.fields]

    def __str__(self):
        result = 'typedef struct _%s {\n' % self.name

        if not self.ordered_fields:
            # Empty structs are not allowed in C standard.
            # Therefore add a dummy field if an empty message occurs.
            result += '    uint8_t dummy_field;'

        result += '\n'.join([str(f) for f in self.ordered_fields])
        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.name

        if self.packed:
            result = 'PB_PACKED_STRUCT_START\n' + result
            result += '\nPB_PACKED_STRUCT_END'

        return result

    def types(self):
        '''Return concatenated helper typedefs required by the fields.'''
        result = ""
        for field in self.fields:
            types = field.types()
            if types is not None:
                result += types + '\n'
        return result

    def default_decl(self, declaration_only = False):
        '''Return the default-value declarations/definitions of all fields.'''
        result = ""
        for field in self.fields:
            default = field.default_decl(declaration_only)
            if default is not None:
                result += default + '\n'
        return result

    def fields_declaration(self):
        '''Return the extern declaration of the message's pb_field_t array.'''
        result = 'extern const pb_field_t %s_fields[%d];' % (self.name, len(self.fields) + 1)
        return result

    def fields_definition(self):
        '''Return the pb_field_t array definition, terminated by PB_LAST_FIELD.'''
        result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, len(self.fields) + 1)

        prev = None
        for field in self.ordered_fields:
            result += field.pb_field_t(prev)
            result += ',\n'
            prev = field.name

        result += '    PB_LAST_FIELD\n};'
        return result

    def encoded_size(self, allmsgs):
        '''Return the maximum size that this message can take when encoded.
        If the size cannot be determined, returns None.
        '''
        size = EncodedSize(0)
        for field in self.fields:
            fsize = field.encoded_size(allmsgs)
            if fsize is None:
                return None
            size += fsize

        return size


# ---------------------------------------------------------------------------
#                    Processing of entire .proto files
# ---------------------------------------------------------------------------


def iterate_messages(desc, names = Names()):
    '''Recursively find all messages. For each, yield name, DescriptorProto.'''
    if hasattr(desc, 'message_type'):
        submsgs = desc.message_type
    else:
        submsgs = desc.nested_type

    for submsg in submsgs:
        sub_names = names + submsg.name
        yield sub_names, submsg

        for x in iterate_messages(submsg, sub_names):
            yield x

def iterate_extensions(desc, names = Names()):
    '''Recursively find all extensions.
    For each, yield name, FieldDescriptorProto.
    '''
    for extension in desc.extension:
        yield names, extension

    for subname, subdesc in iterate_messages(desc, names):
        for extension in subdesc.extension:
            yield subname, extension

def parse_file(fdesc, file_options):
    '''Takes a FileDescriptorProto and returns tuple (enums, messages, extensions).'''

    enums = []
    messages = []
    extensions = []

    if fdesc.package:
        base_name = Names(fdesc.package.split('.'))
    else:
        base_name = Names()

    for enum in fdesc.enum_type:
        enum_options = get_nanopb_suboptions(enum, file_options, base_name + enum.name)
        enums.append(Enum(base_name, enum, enum_options))

    for names, message in iterate_messages(fdesc, base_name):
        message_options = get_nanopb_suboptions(message, file_options, names)
        messages.append(Message(names, message, message_options))
        for enum in message.enum_type:
            enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
            enums.append(Enum(names, enum, enum_options))

    for names, extension in iterate_extensions(fdesc, base_name):
        field_options = get_nanopb_suboptions(extension, file_options, names)
        if field_options.type != nanopb_pb2.FT_IGNORE:
            extensions.append(ExtensionField(names, extension, field_options))

    # Fix field default values where enum short names are used.
    for enum in enums:
        if not enum.options.long_names:
            for message in messages:
                for field in message.fields:
                    if field.default in enum.value_longnames:
                        idx = enum.value_longnames.index(field.default)
                        field.default = enum.values[idx][0]

    return enums, messages, extensions

def toposort2(data):
    '''Topological sort.
    From http://code.activestate.com/recipes/577413-topological-sort/
    This function is under the MIT license.

    data maps item -> set of items it depends on; it is modified in place.
    '''
    for k, v in data.items():
        v.discard(k) # Ignore self dependencies
    extra_items_in_deps = reduce(set.union, data.values(), set()) - set(data.keys())
    data.update(dict([(item, set()) for item in extra_items_in_deps]))
    while True:
        ordered = set(item for item,dep in data.items() if not dep)
        if not ordered:
            break
        for item in sorted(ordered):
            yield item
        data = dict([(item, (dep - ordered)) for item,dep in data.items()
                if item not in ordered])
    assert not data, "A cyclic dependency exists amongst %r" % data

def sort_dependencies(messages):
    '''Sort a list of Messages based on dependencies.'''
    dependencies = {}
    message_by_name = {}
    for message in messages:
        dependencies[str(message.name)] = set(message.get_dependencies())
        message_by_name[str(message.name)] = message

    for msgname in toposort2(dependencies):
        # Dependencies that are not messages of this file (e.g. scalar C
        # types) are also yielded by toposort2; skip them.
        if msgname in message_by_name:
            yield message_by_name[msgname]

def make_identifier(headername):
    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
    result = ""
    for c in headername.upper():
        if c.isalnum():
            result += c
        else:
            result += '_'
    return result

def generate_header(dependencies, headername, enums, messages, extensions, options):
    '''Generate content for a header file.
    Generates strings, which should be concatenated and stored to file.
    '''

    yield '/* Automatically generated nanopb header */\n'
    if options.notimestamp:
        yield '/* Generated by %s */\n\n' % (nanopb_version)
    else:
        yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())

    symbol = make_identifier(headername)
    yield '#ifndef _PB_%s_\n' % symbol
    yield '#define _PB_%s_\n' % symbol
    try:
        yield options.libformat % ('pb.h')
    except TypeError:
        # no %s specified - use whatever was passed in as options.libformat
        yield options.libformat
    yield '\n'

    for dependency in dependencies:
        noext = os.path.splitext(dependency)[0]
        yield options.genformat % (noext + '.' + options.extension + '.h')
        yield '\n'

    yield '#ifdef __cplusplus\n'
    yield 'extern "C" {\n'
    yield '#endif\n\n'

    yield '/* Enum definitions */\n'
    for enum in enums:
        yield str(enum) + '\n\n'

    yield '/* Struct definitions */\n'
    for msg in sort_dependencies(messages):
        yield msg.types()
        yield str(msg) + '\n\n'

    if extensions:
        yield '/* Extensions */\n'
        for extension in extensions:
            yield extension.extension_decl()
        yield '\n'

    yield '/* Default values for struct fields */\n'
    for msg in messages:
        yield msg.default_decl(True)
    yield '\n'

    yield '/* Field tags (for use in manual encoding/decoding) */\n'
    for msg in sort_dependencies(messages):
        for field in msg.fields:
            yield field.tags()
    for extension in extensions:
        yield extension.tags()
    yield '\n'

    yield '/* Struct field encoding specification for nanopb */\n'
    for msg in messages:
        yield msg.fields_declaration() + '\n'
    yield '\n'

    yield '/* Maximum encoded size of messages (where known) */\n'
    for msg in messages:
        msize = msg.encoded_size(messages)
        if msize is not None:
            identifier = '%s_size' % msg.name
            yield '#define %-40s %s\n' % (identifier, msize)
    yield '\n'

    yield '#ifdef __cplusplus\n'
    yield '} /* extern "C" */\n'
    yield '#endif\n'

    # End of header
    yield '\n#endif\n'

def generate_source(headername, enums, messages, extensions, options):
    '''Generate content for a source file.

    Yields strings which should be concatenated and stored to file. Besides
    the defaults and field arrays, emits compile-time checks that the
    configured PB_MAX_REQUIRED_FIELDS / PB_FIELD_16BIT / PB_FIELD_32BIT /
    sizeof(double) settings suit the messages.
    '''

    yield '/* Automatically generated nanopb constant definitions */\n'
    if options.notimestamp:
        yield '/* Generated by %s */\n\n' % (nanopb_version)
    else:
        yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
    yield options.genformat % (headername)
    yield '\n'

    for msg in messages:
        yield msg.default_decl(False)

    yield '\n\n'

    for msg in messages:
        yield msg.fields_definition() + '\n\n'

    for ext in extensions:
        yield ext.extension_def() + '\n'

    # Add checks for numeric limits
    if messages:
        # Count required fields of the message passed in as 'm' (not the
        # leftover loop variable), so every message is measured separately.
        count_required_fields = lambda m: len([f for f in m.fields if f.rules == 'REQUIRED'])
        largest_msg = max(messages, key = count_required_fields)
        largest_count = count_required_fields(largest_msg)
        if largest_count > 64:
            yield '\n/* Check that missing required fields will be properly detected */\n'
            yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
            yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
            yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
            yield '#endif\n'

    # worst: largest numeric tag/size/count seen; checks: C expressions
    # (pb_membersize of submessages) that can only be evaluated by the compiler.
    worst = 0
    worst_field = ''
    checks = []
    checks_msgnames = []
    for msg in messages:
        checks_msgnames.append(msg.name)
        for field in msg.fields:
            status = field.largest_field_value()
            if isinstance(status, (str, unicode)):
                checks.append(status)
            elif status > worst:
                worst = status
                worst_field = str(field.struct_name) + '.' + str(field.name)

    if worst > 255 or checks:
        yield '\n/* Check that field information fits in pb_field_t */\n'

        if worst > 65535 or checks:
            yield '#if !defined(PB_FIELD_32BIT)\n'
            if worst > 65535:
                yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
            else:
                assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
                msgs = '_'.join(str(n) for n in checks_msgnames)
                yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
                yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
                yield ' * \n'
                yield ' * The reason you need to do this is that some of your messages contain tag\n'
                yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
                yield ' * field descriptors.\n'
                yield ' */\n'
                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
            yield '#endif\n\n'

        if worst < 65536:
            yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
            if worst > 255:
                yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
            else:
                assertion = ' && '.join(str(c) + ' < 256' for c in checks)
                msgs = '_'.join(str(n) for n in checks_msgnames)
                yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
                yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
                yield ' * \n'
                yield ' * The reason you need to do this is that some of your messages contain tag\n'
                yield ' * numbers or field sizes that are larger than what can fit in the default\n'
                yield ' * 8 bit descriptors.\n'
                yield ' */\n'
                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
            yield '#endif\n\n'

    # Add check for sizeof(double)
    has_double = False
    for msg in messages:
        for field in msg.fields:
            if field.ctype == 'double':
                has_double = True

    if has_double:
        yield '\n'
        yield '/* On some platforms (such as AVR), double is really float.\n'
        yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
        yield ' * To get rid of this error, remove any double fields from your .proto.\n'
        yield ' */\n'
        yield 'STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'

    yield '\n'

# ---------------------------------------------------------------------------
#                    Options parsing for the .proto files
# ---------------------------------------------------------------------------

from fnmatch import fnmatch

def read_options_file(infile):
    '''Parse a separate options file to list:
        [(namemask, options), ...]

    Each non-empty line that does not start with '//' or '#' has the form
    "<namemask> <options in protobuf text format>".
    Raises ValueError for a line that has a namemask but no options.
    '''
    results = []
    for lineno, line in enumerate(infile, 1):
        line = line.strip()
        if not line or line.startswith('//') or line.startswith('#'):
            continue

        parts = line.split(None, 1)
        if len(parts) < 2:
            raise ValueError("Options file line %d has no options after "
                             "the name mask: %r" % (lineno, line))
        opts = nanopb_pb2.NanoPBOptions()
        text_format.Merge(parts[1], opts)
        results.append((parts[0], opts))

    return results

class Globals:
    '''Ugly global variables, should find a good way to pass these.'''
    verbose_options = False
    separate_options = []
    matched_namemasks = set()

def get_nanopb_suboptions(subdesc, options, name):
    '''Get copy of options, and merge information from subdesc.

    The result is built in this order, each step merged on top of the
    previous one: a copy of 'options', every entry of
    Globals.separate_options whose namemask matches 'name' (fnmatch on the
    dot-joined parts), and finally the nanopb extension set on
    subdesc.options in the .proto file.

    subdesc: descriptor whose .options is FieldOptions, FileOptions,
             MessageOptions or EnumOptions; anything else raises Exception.
    options: parent NanoPBOptions; it is copied, not modified.
    name:    Names of the element, used for namemask matching.
    '''
    new_options = nanopb_pb2.NanoPBOptions()
    new_options.CopyFrom(options)

    # Handle options defined in a separate file
    dotname = '.'.join(name.parts)
    for namemask, file_opts in Globals.separate_options:
        if fnmatch(dotname, namemask):
            Globals.matched_namemasks.add(namemask)
            new_options.MergeFrom(file_opts)

    # Handle options defined in .proto
    if isinstance(subdesc.options, descriptor.FieldOptions):
        ext_type = nanopb_pb2.nanopb
    elif isinstance(subdesc.options, descriptor.FileOptions):
        ext_type = nanopb_pb2.nanopb_fileopt
    elif isinstance(subdesc.options, descriptor.MessageOptions):
        ext_type = nanopb_pb2.nanopb_msgopt
    elif isinstance(subdesc.options, descriptor.EnumOptions):
        ext_type = nanopb_pb2.nanopb_enumopt
    else:
        raise Exception("Unknown options type")

    if subdesc.options.HasExtension(ext_type):
        ext = subdesc.options.Extensions[ext_type]
        new_options.MergeFrom(ext)

    if Globals.verbose_options:
        sys.stderr.write("Options for " + dotname + ": ")
        sys.stderr.write(text_format.MessageToString(new_options) + "\n")

    return new_options


# ---------------------------------------------------------------------------
#                         Command line interface
# ---------------------------------------------------------------------------

import sys
import os.path
from optparse import OptionParser

# Command line / plugin-parameter parser shared by main_cli() and main_plugin().
# NOTE: for the action="append" options (-x, -s), optparse appends to the
# default list object itself. parse_args() is therefore only safe to call once
# per process, which is what both entry points below do.
optparser = OptionParser(
    usage = "Usage: nanopb_generator.py [options] file.pb ...",
    epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
             "Output will be written to file.pb.h and file.pb.c.")
optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
    help="Exclude file from generated #include list.")
optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default="pb",
    help="Set extension to use instead of 'pb' for generated files. [default: %default]")
optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
    help="Set name of a separate generator options file.")
optparser.add_option("-Q", "--generated-include-format", dest="genformat",
    metavar="FORMAT", default='#include "%s"\n',
    help="Set format string to use for including other .pb.h files. [default: %default]")
optparser.add_option("-L", "--library-include-format", dest="libformat",
    metavar="FORMAT", default='#include <%s>\n',
    help="Set format string to use for including the nanopb pb.h header. [default: %default]")
optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False,
    help="Don't add timestamp to .pb.h and .pb.c preambles")
optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
    help="Don't print anything except errors.")
optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
    help="Print more information.")
optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
    help="Set generator option (max_size, max_count etc.).")

def process_file(filename, fdesc, options):
    '''Process a single file.
    filename: The full path to the .proto or .pb source file, as string.
    fdesc: The loaded FileDescriptorProto, or None to read it from the input
           file (taken as the first entry of a serialized FileDescriptorSet).
    options: Command line options as they come from OptionParser.

    Returns a dict:
        {'headername': Name of header file,
         'headerdata': Data for the .h header file,
         'sourcename': Name of the source code file,
         'sourcedata': Data for the .c source code file
        }

    Side effects: resets Globals.separate_options and Globals.matched_namemasks
    (module-level state), and may write warnings to stderr.
    '''
    # Generator options given with -s on the command line apply to everything.
    toplevel_options = nanopb_pb2.NanoPBOptions()
    for s in options.settings:
        text_format.Merge(s, toplevel_options)

    if fdesc is None:
        # Close the input file deterministically instead of relying on GC.
        with open(filename, 'rb') as f:
            data = f.read()
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]

    # Check if there is a separate .options file
    try:
        optfilename = options.options_file % os.path.splitext(filename)[0]
    except TypeError:
        # No %s specified, use the filename as-is
        optfilename = options.options_file

    if os.path.isfile(optfilename):
        if options.verbose:
            sys.stderr.write('Reading options from ' + optfilename + '\n')

        # read_options_file() consumes the whole file and returns a list, so
        # the file can be closed as soon as it returns. "rU" requests
        # universal-newline reading (Python 2 mode string).
        with open(optfilename, "rU") as optfile:
            Globals.separate_options = read_options_file(optfile)
    else:
        Globals.separate_options = []
    Globals.matched_namemasks = set()

    # Parse the file
    file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
    enums, messages, extensions = parse_file(fdesc, file_options)

    # Decide the file names
    noext = os.path.splitext(filename)[0]
    headername = noext + '.' + options.extension + '.h'
    sourcename = noext + '.' + options.extension + '.c'
    headerbasename = os.path.basename(headername)

    # List of .proto files that should not be included in the C header file
    # even if they are mentioned in the source .proto.
    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude
    dependencies = [d for d in fdesc.dependency if d not in excludes]

    headerdata = ''.join(generate_header(dependencies, headerbasename, enums,
                                         messages, extensions, options))

    sourcedata = ''.join(generate_source(headerbasename, enums,
                                         messages, extensions, options))

    # Check if there were any lines in .options that did not match a member.
    # get_nanopb_suboptions() records every matching mask in
    # Globals.matched_namemasks, so this check must run after generation.
    unmatched = [n for n, o in Globals.separate_options if n not in Globals.matched_namemasks]
    if unmatched and not options.quiet:
        sys.stderr.write("Following patterns in " + optfilename + " did not match any fields: "
                         + ', '.join(unmatched) + "\n")
        if not Globals.verbose_options:
            sys.stderr.write("Use protoc --nanopb-out=-v:. to see a list of the field names.\n")

    return {'headername': headername, 'headerdata': headerdata,
            'sourcename': sourcename, 'sourcedata': sourcedata}

def _apply_verbosity(options):
    '''Let --quiet override --verbose and publish the result to Globals.

    Shared by both entry points so the CLI and the protoc plugin treat the
    -q / -v flags the same way.
    '''
    if options.quiet:
        options.verbose = False

    Globals.verbose_options = options.verbose

def _write_text_file(path, data):
    '''Write data to path in text mode, closing the file before returning.'''
    with open(path, 'w') as f:
        f.write(data)

def main_cli():
    '''Main function when invoked directly from the command line.

    Parses sys.argv, generates code for every input file and writes the .h/.c
    pair next to each input (same path, extension replaced). Prints usage and
    exits with status 1 when no input files are given.
    '''

    options, filenames = optparser.parse_args()

    if not filenames:
        optparser.print_help()
        sys.exit(1)

    _apply_verbosity(options)

    for filename in filenames:
        results = process_file(filename, None, options)

        if not options.quiet:
            sys.stderr.write("Writing to " + results['headername'] + " and "
                             + results['sourcename'] + "\n")

        _write_text_file(results['headername'], results['headerdata'])
        _write_text_file(results['sourcename'], results['sourcedata'])

def main_plugin():
    '''Main function when invoked as a protoc plugin.

    Reads a serialized CodeGeneratorRequest from stdin, takes generator
    options from the request's parameter string, generates a .h/.c pair for
    every requested file and writes a serialized CodeGeneratorResponse to
    stdout. Diagnostics go to stderr because stdout carries the response.
    '''

    if sys.platform == "win32":
        import os, msvcrt
        # Set stdin and stdout to binary mode so newline translation does
        # not corrupt the serialized protobuf data.
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    data = sys.stdin.read()
    request = plugin_pb2.CodeGeneratorRequest.FromString(data)

    import shlex
    args = shlex.split(request.parameter)
    options, dummy = optparser.parse_args(args)

    _apply_verbosity(options)

    response = plugin_pb2.CodeGeneratorResponse()

    for filename in request.file_to_generate:
        for fdesc in request.proto_file:
            if fdesc.name == filename:
                results = process_file(filename, fdesc, options)

                f = response.file.add()
                f.name = results['headername']
                f.content = results['headerdata']

                f = response.file.add()
                f.name = results['sourcename']
                f.content = results['sourcedata']

    sys.stdout.write(response.SerializeToString())

if __name__ == '__main__':
    # Check if we are running as a plugin under protoc
    if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv:
        main_plugin()
    else:
        main_cli()