#!/usr/bin/env python

from __future__ import unicode_literals

'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
nanopb_version = "nanopb-0.3.7-dev"

import sys
import re
from functools import reduce

try:
    # Add some dummy imports to keep packaging tools happy.
    import google, distutils.util # bbfreeze seems to need these
    import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
except:
    # Don't care, we will error out later if it is actually important.
    pass

try:
    import google.protobuf.text_format as text_format
    import google.protobuf.descriptor_pb2 as descriptor
except:
    sys.stderr.write('''
         *************************************************************
         *** Could not import the Google protobuf Python libraries ***
         *** Try installing package 'python-protobuf' or similar.  ***
         *************************************************************
    ''' + '\n')
    raise

try:
    import proto.nanopb_pb2 as nanopb_pb2
    import proto.plugin_pb2 as plugin_pb2
except:
    sys.stderr.write('''
         ********************************************************************
         *** Failed to import the protocol definitions for generator.    ***
         *** You have to run 'make' in the nanopb/generator/proto folder. ***
         ********************************************************************
    ''' + '\n')
    raise

# ---------------------------------------------------------------------------
#                     Generation of single fields
# ---------------------------------------------------------------------------

import time
import os.path

# Values are tuple (c type, pb type, encoded size, int_size_allowed)
FieldD = descriptor.FieldDescriptorProto
datatypes = {
    FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1,  False),
    FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8,  False),
    FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4,  False),
    FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8,  False),
    FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4,  False),
    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10,  True),
    FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10,  True),
    FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4,  False),
    FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8,  False),
    FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5,  True),
    FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10,  True),
    FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5,  True),
    FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10,  True)
}

# Integer size overrides (from .proto settings)
intsizes = {
    nanopb_pb2.IS_8:   'int8_t',
    nanopb_pb2.IS_16:  'int16_t',
    nanopb_pb2.IS_32:  'int32_t',
    nanopb_pb2.IS_64:  'int64_t',
}

# String types (for python 2 / python 3 compatibility)
try:
    strtypes = (unicode, str)
except NameError:
    strtypes = (str, )

class Names:
    '''Keeps a set of nested names and formats them to C identifier.'''
    def __init__(self, parts = ()):
        if isinstance(parts, Names):
            parts = parts.parts
        self.parts = tuple(parts)

    def __str__(self):
        return '_'.join(self.parts)

    def __add__(self, other):
        if isinstance(other, strtypes):
            return Names(self.parts + (other,))
        elif isinstance(other, tuple):
            return Names(self.parts + other)
        else:
            raise ValueError("Name parts should be of type str")

    def __eq__(self, other):
        return isinstance(other, Names) and self.parts == other.parts

def names_from_type_name(type_name):
    '''Parse Names() from FieldDescriptorProto type_name'''
    if type_name[0] != '.':
        raise NotImplementedError("Lookup of non-absolute type names is not supported")
    return Names(type_name[1:].split('.'))

def varint_max_size(max_value):
    '''Returns the maximum number of bytes a varint can take when encoded.'''
    if max_value < 0:
        max_value = 2**64 - max_value
    for i in range(1, 11):
        if (max_value >> (i * 7)) == 0:
            return i
    raise ValueError("Value too large for varint: " + str(max_value))

assert varint_max_size(-1) == 10
assert varint_max_size(0) == 1
assert varint_max_size(127) == 1
assert varint_max_size(128) == 2

class EncodedSize:
    '''Class used to represent the encoded size of a field or a message.
    Consists of a combination of symbolic sizes and integer sizes.'''
    def __init__(self, value = 0, symbols = []):
        if isinstance(value, EncodedSize):
            self.value = value.value
            self.symbols = value.symbols
        elif isinstance(value, strtypes + (Names,)):
            self.symbols = [str(value)]
            self.value = 0
        else:
            self.value = value
            self.symbols = symbols

    def __add__(self, other):
        if isinstance(other, int):
            return EncodedSize(self.value + other, self.symbols)
        elif isinstance(other, strtypes + (Names,)):
            return EncodedSize(self.value, self.symbols + [str(other)])
        elif isinstance(other, EncodedSize):
            return EncodedSize(self.value + other.value, self.symbols + other.symbols)
        else:
            raise ValueError("Cannot add size: " + repr(other))

    def __mul__(self, other):
        if isinstance(other, int):
            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
        else:
            raise ValueError("Cannot multiply size: " + repr(other))

    def __str__(self):
        if not self.symbols:
            return str(self.value)
        else:
            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'

    def upperlimit(self):
        if not self.symbols:
            return self.value
        else:
            return 2**32 - 1

class Enum:
    def __init__(self, names, desc, enum_options):
        '''desc is EnumDescriptorProto'''

        self.options = enum_options
        self.names = names + desc.name

        if enum_options.long_names:
            self.values = [(self.names + x.name, x.number) for x in desc.value]
        else:
            self.values = [(names + x.name, x.number) for x in desc.value]

        self.value_longnames = [self.names + x.name for x in desc.value]
        self.packed = enum_options.packed_enum

    def has_negative(self):
        for n, v in self.values:
            if v < 0:
                return True
        return False

    def encoded_size(self):
        return max([varint_max_size(v) for n,v in self.values])

    def __str__(self):
        result = 'typedef enum _%s {\n' % self.names
        result += ',\n'.join(["    %s = %d" % x for x in self.values])
        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.names

        result += '\n#define _%s_MIN %s' % (self.names, self.values[0][0])
        result += '\n#define _%s_MAX %s' % (self.names, self.values[-1][0])
        result += '\n#define _%s_ARRAYSIZE ((%s)(%s+1))' % (self.names, self.names, self.values[-1][0])

        if not self.options.long_names:
            # Define the long names always so that enum value references
            # from other files work properly.
            for i, x in enumerate(self.values):
                result += '\n#define %s %s' % (self.value_longnames[i], x[0])

        return result

class FieldMaxSize:
    def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
        if isinstance(worst, list):
            self.worst = max(i for i in worst if i is not None)
        else:
            self.worst = worst

        self.worst_field = field_name
        self.checks = list(checks)

    def extend(self, extend, field_name = None):
        self.worst = max(self.worst, extend.worst)

        if self.worst == extend.worst:
            self.worst_field = extend.worst_field

        self.checks.extend(extend.checks)

class Field:
    def __init__(self, struct_name, desc, field_options):
        '''desc is FieldDescriptorProto'''
        self.tag = desc.number
        self.struct_name = struct_name
        self.union_name = None
        self.name = desc.name
        self.default = None
        self.max_size = None
        self.max_count = None
        self.array_decl = ""
        self.enc_size = None
        self.ctype = None

        self.inline = None
        if field_options.type == nanopb_pb2.FT_INLINE:
            field_options.type = nanopb_pb2.FT_STATIC
            self.inline = nanopb_pb2.FT_INLINE

        # Parse field options
        if field_options.HasField("max_size"):
            self.max_size = field_options.max_size

        if field_options.HasField("max_count"):
            self.max_count = field_options.max_count

        if desc.HasField('default_value'):
            self.default = desc.default_value

        # Check field rules, i.e. required/optional/repeated.
        can_be_static = True
        if desc.label == FieldD.LABEL_REQUIRED:
            self.rules = 'REQUIRED'
        elif desc.label == FieldD.LABEL_OPTIONAL:
            self.rules = 'OPTIONAL'
        elif desc.label == FieldD.LABEL_REPEATED:
            self.rules = 'REPEATED'
            if self.max_count is None:
                can_be_static = False
            else:
                self.array_decl = '[%d]' % self.max_count
        else:
            raise NotImplementedError(desc.label)

        # Check if the field can be implemented with static allocation
        # i.e. whether the data size is known.
        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
            can_be_static = False

        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
            can_be_static = False

        # Decide how the field data will be allocated
        if field_options.type == nanopb_pb2.FT_DEFAULT:
            if can_be_static:
                field_options.type = nanopb_pb2.FT_STATIC
            else:
                field_options.type = nanopb_pb2.FT_CALLBACK

        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
            raise Exception("Field %s is defined as static, but max_size or "
                            "max_count is not given." % self.name)

        if field_options.type == nanopb_pb2.FT_STATIC:
            self.allocation = 'STATIC'
        elif field_options.type == nanopb_pb2.FT_POINTER:
            self.allocation = 'POINTER'
        elif field_options.type == nanopb_pb2.FT_CALLBACK:
            self.allocation = 'CALLBACK'
        else:
            raise NotImplementedError(field_options.type)

        # Decide the C data type to use in the struct.
        if desc.type in datatypes:
            self.ctype, self.pbtype, self.enc_size, isa = datatypes[desc.type]

            # Override the field size if user wants to use smaller integers
            if isa and field_options.int_size != nanopb_pb2.IS_DEFAULT:
                self.ctype = intsizes[field_options.int_size]
                if desc.type == FieldD.TYPE_UINT32 or desc.type == FieldD.TYPE_UINT64:
                    self.ctype = 'u' + self.ctype;
        elif desc.type == FieldD.TYPE_ENUM:
            self.pbtype = 'ENUM'
            self.ctype = names_from_type_name(desc.type_name)
            if self.default is not None:
                self.default = self.ctype + self.default
            self.enc_size = None # Needs to be filled in when enum values are known
        elif desc.type == FieldD.TYPE_STRING:
            self.pbtype = 'STRING'
            self.ctype = 'char'
            if self.allocation == 'STATIC':
                self.ctype = 'char'
                self.array_decl += '[%d]' % self.max_size
                self.enc_size = varint_max_size(self.max_size) + self.max_size
        elif desc.type == FieldD.TYPE_BYTES:
            self.pbtype = 'BYTES'
            if self.allocation == 'STATIC':
                # Inline STATIC for BYTES is like STATIC for STRING.
                if self.inline:
                    self.ctype = 'pb_byte_t'
                    self.array_decl += '[%d]' % self.max_size
                else:
                    self.ctype = self.struct_name + self.name + 't'
                self.enc_size = varint_max_size(self.max_size) + self.max_size
            elif self.allocation == 'POINTER':
                self.ctype = 'pb_bytes_array_t'
        elif desc.type == FieldD.TYPE_MESSAGE:
            self.pbtype = 'MESSAGE'
            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
            self.enc_size = None # Needs to be filled in after the message type is available
        else:
            raise NotImplementedError(desc.type)

    def __lt__(self, other):
        return self.tag < other.tag

    def __str__(self):
        result = ''
        if self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                result += '    pb_size_t ' + self.name + '_count;\n'

            if self.pbtype == 'MESSAGE':
                # Use struct definition, so recursive submessages are possible
                result += '    struct _%s *%s;' % (self.ctype, self.name)
            elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
                # String/bytes arrays need to be defined as pointers to pointers
                result += '    %s **%s;' % (self.ctype, self.name)
            else:
                result += '    %s *%s;' % (self.ctype, self.name)
        elif self.allocation == 'CALLBACK':
            result += '    pb_callback_t %s;' % self.name
        else:
            if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
                result += '    bool has_' + self.name + ';\n'
            elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
                result += '    pb_size_t ' + self.name + '_count;\n'
            result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
        return result

    def types(self):
        '''Return definitions for any special types this field might need.'''
        if self.pbtype == 'BYTES' and self.allocation == 'STATIC' and not self.inline:
            result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
        else:
            result = ''
        return result

    def get_dependencies(self):
        '''Get list of type names used by this field.'''
        if self.allocation == 'STATIC':
            return [str(self.ctype)]
        else:
            return []

    def get_initializer(self, null_init, inner_init_only = False):
        '''Return literal expression for this field's default value.
        null_init: If True, initialize to a 0 value instead of default from .proto
        inner_init_only: If True, exclude initialization for any count/has fields
        '''

        inner_init = None
        if self.pbtype == 'MESSAGE':
            if null_init:
                inner_init = '%s_init_zero' % self.ctype
            else:
                inner_init = '%s_init_default' % self.ctype
        elif self.default is None or null_init:
            if self.pbtype == 'STRING':
                inner_init = '""'
            elif self.pbtype == 'BYTES':
                if self.inline:
                    inner_init = '{0}'
                else:
                    inner_init = '{0, {0}}'
            elif self.pbtype in ('ENUM', 'UENUM'):
                inner_init = '(%s)0' % self.ctype
            else:
                inner_init = '0'
        else:
            if self.pbtype == 'STRING':
                inner_init = self.default.replace('"', '\\"')
                inner_init = '"' + inner_init + '"'
            elif self.pbtype == 'BYTES':
                data = ['0x%02x' % ord(c) for c in self.default]
                if len(data) == 0:
                    if self.inline:
                        inner_init = '{0}'
                    else:
                        inner_init = '{0, {0}}'
                else:
                    if self.inline:
                        inner_init = '{%s}' % ','.join(data)
                    else:
                        inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
            elif self.pbtype in ['FIXED32', 'UINT32']:
                inner_init = str(self.default) + 'u'
            elif self.pbtype in ['FIXED64', 'UINT64']:
                inner_init = str(self.default) + 'ull'
            elif self.pbtype in ['SFIXED64', 'INT64']:
                inner_init = str(self.default) + 'll'
            else:
                inner_init = str(self.default)

        if inner_init_only:
            return inner_init

        outer_init = None
        if self.allocation == 'STATIC':
            if self.rules == 'REPEATED':
                outer_init = '0, {'
                outer_init += ', '.join([inner_init] * self.max_count)
                outer_init += '}'
            elif self.rules == 'OPTIONAL':
                outer_init = 'false, ' + inner_init
            else:
                outer_init = inner_init
        elif self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                outer_init = '0, NULL'
            else:
                outer_init = 'NULL'
        elif self.allocation == 'CALLBACK':
            if self.pbtype == 'EXTENSION':
                outer_init = 'NULL'
            else:
                outer_init = '{{NULL}, NULL}'

        return outer_init

    def default_decl(self, declaration_only = False):
        '''Return definition for this field's default value.'''
        if self.default is None:
            return None

        ctype = self.ctype
        default = self.get_initializer(False, True)
        array_decl = ''

        if self.pbtype == 'STRING':
            if self.allocation != 'STATIC':
                return None # Not implemented
            array_decl = '[%d]' % self.max_size
        elif self.pbtype == 'BYTES':
            if self.allocation != 'STATIC':
                return None # Not implemented
            if self.inline:
                array_decl = '[%d]' % self.max_size

        if declaration_only:
            return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
        else:
            return 'const %s %s_default%s = %s;' % (ctype, self.struct_name + self.name, array_decl, default)

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_%s_tag' % (self.struct_name, self.name)
        return '#define %-40s %d\n' % (identifier, self.tag)

    def pb_field_t(self, prev_field_name):
        '''Return the pb_field_t initializer to use in the constant array.
        prev_field_name is the name of the previous field or None.
        '''

        if self.rules == 'ONEOF':
            if self.anonymous:
                result = '    PB_ANONYMOUS_ONEOF_FIELD(%s, ' % self.union_name
            else:
                result = '    PB_ONEOF_FIELD(%s, ' % self.union_name
        else:
            result = '    PB_FIELD('

        result += '%3d, ' % self.tag
        result += '%-8s, ' % self.pbtype
        result += '%s, ' % self.rules
        result += '%-8s, ' % (self.allocation if not self.inline else "INLINE")
        result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
        result += '%s, ' % self.struct_name
        result += '%s, ' % self.name
        result += '%s, ' % (prev_field_name or self.name)

        if self.pbtype == 'MESSAGE':
            result += '&%s_fields)' % self.submsgname
        elif self.default is None:
            result += '0)'
        elif self.pbtype in ['BYTES', 'STRING'] and self.allocation != 'STATIC':
            result += '0)' # Arbitrary size default values not implemented
        elif self.rules == 'OPTEXT':
            result += '0)' # Default value for extensions is not implemented
        else:
            result += '&%s_default)' % (self.struct_name + self.name)

        return result

    def get_last_field_name(self):
        return self.name

    def largest_field_value(self):
        '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
           Returns numeric value or a C-expression for assert.'''
        check = []
        if self.pbtype == 'MESSAGE':
            if self.rules == 'REPEATED' and self.allocation == 'STATIC':
                check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
            elif self.rules == 'ONEOF':
                if self.anonymous:
                    check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
                else:
                    check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
            else:
                check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))

        return FieldMaxSize([self.tag, self.max_size, self.max_count],
                            check,
                            ('%s.%s' % (self.struct_name, self.name)))

    def encoded_size(self, dependencies):
        '''Return the maximum size that this field can take when encoded,
        including the field tag. If the size cannot be determined, returns
        None.'''

        if self.allocation != 'STATIC':
            return None

        if self.pbtype == 'MESSAGE':
            encsize = None
            if str(self.submsgname) in dependencies:
                submsg = dependencies[str(self.submsgname)]
                encsize = submsg.encoded_size(dependencies)
                if encsize is not None:
                    # Include submessage length prefix
                    encsize += varint_max_size(encsize.upperlimit())

            if encsize is None:
                # Submessage or its size cannot be found.
                # This can occur if submessage is defined in different
                # file, and it or its .options could not be found.
                # Instead of direct numeric value, reference the size that
                # has been #defined in the other file.
                encsize = EncodedSize(self.submsgname + 'size')

                # We will have to make a conservative assumption on the length
                # prefix size, though.
                encsize += 5

        elif self.pbtype in ['ENUM', 'UENUM']:
            if str(self.ctype) in dependencies:
                enumtype = dependencies[str(self.ctype)]
                encsize = enumtype.encoded_size()
            else:
                # Conservative assumption
                encsize = 10

        elif self.enc_size is None:
            raise RuntimeError("Could not determine encoded size for %s.%s"
                               % (self.struct_name, self.name))
        else:
            encsize = EncodedSize(self.enc_size)

        encsize += varint_max_size(self.tag << 3) # Tag + wire type

        if self.rules == 'REPEATED':
            # Decoders must be always able to handle unpacked arrays.
            # Therefore we have to reserve space for it, even though
            # we emit packed arrays ourselves.
            encsize *= self.max_count

        return encsize


class ExtensionRange(Field):
    def __init__(self, struct_name, range_start, field_options):
        '''Implements a special pb_extension_t* field in an extensible message
        structure. The range_start signifies the index at which the extensions
        start. Not necessarily all tags above this are extensions, it is merely
        a speed optimization.
        '''
        self.tag = range_start
        self.struct_name = struct_name
        self.name = 'extensions'
        self.pbtype = 'EXTENSION'
        self.rules = 'OPTIONAL'
        self.allocation = 'CALLBACK'
        self.ctype = 'pb_extension_t'
        self.array_decl = ''
        self.default = None
        self.max_size = 0
        self.max_count = 0
        self.inline = None

    def __str__(self):
        return '    pb_extension_t *extensions;'

    def types(self):
        return ''

    def tags(self):
        return ''

    def encoded_size(self, dependencies):
        # We exclude extensions from the count, because they cannot be known
        # until runtime. Other option would be to return None here, but this
        # way the value remains useful if extensions are not used.
        return EncodedSize(0)

class ExtensionField(Field):
    def __init__(self, struct_name, desc, field_options):
        self.fullname = struct_name + desc.name
        self.extendee_name = names_from_type_name(desc.extendee)
        Field.__init__(self, self.fullname + 'struct', desc, field_options)

        if self.rules != 'OPTIONAL':
            self.skip = True
        else:
            self.skip = False
            self.rules = 'OPTEXT'

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_tag' % self.fullname
        return '#define %-40s %d\n' % (identifier, self.tag)

    def extension_decl(self):
        '''Declaration of the extension type in the .pb.h file'''
        if self.skip:
            msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
            msg +='   type of extension fields is currently supported. */\n'
            return msg

        return ('extern const pb_extension_type_t %s; /* field type: %s */\n' %
            (self.fullname, str(self).strip()))

    def extension_def(self):
        '''Definition of the extension type in the .pb.c file'''

        if self.skip:
            return ''

        result = 'typedef struct {\n'
        result += str(self)
        result += '\n} %s;\n\n' % self.struct_name
        result += ('static const pb_field_t %s_field = \n  %s;\n\n' %
                    (self.fullname, self.pb_field_t(None)))
        result += 'const pb_extension_type_t %s = {\n' % self.fullname
        result += '    NULL,\n'
        result += '    NULL,\n'
        result += '    &%s_field\n' % self.fullname
        result += '};\n'
        return result


# ---------------------------------------------------------------------------
#                   Generation of oneofs (unions)
# ---------------------------------------------------------------------------

class OneOf(Field):
    def __init__(self, struct_name, oneof_desc):
        self.struct_name = struct_name
        self.name = oneof_desc.name
        self.ctype = 'union'
        self.pbtype = 'oneof'
        self.fields = []
        self.allocation = 'ONEOF'
        self.default = None
        self.rules = 'ONEOF'
        self.anonymous = False
        self.inline = None

    def add_field(self, field):
        if field.allocation == 'CALLBACK':
            raise Exception("Callback fields inside of oneof are not supported"
                            + " (field %s)" % field.name)

        field.union_name = self.name
        field.rules = 'ONEOF'
        field.anonymous = self.anonymous
        self.fields.append(field)
        self.fields.sort(key = lambda f: f.tag)

        # Sort by the lowest tag number inside union
        self.tag = min([f.tag for f in self.fields])

    def __str__(self):
        result = ''
        if self.fields:
            result += '    pb_size_t which_' + self.name + ";\n"
            result += '    union {\n'
            for f in self.fields:
                result += '    ' + str(f).replace('\n', '\n    ') + '\n'
            if self.anonymous:
                result += '    };'
            else:
                result += '    } ' + self.name + ';'
        return result

    def types(self):
        return ''.join([f.types() for f in self.fields])

    def get_dependencies(self):
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def get_initializer(self, null_init):
        return '0, {' + self.fields[0].get_initializer(null_init) + '}'

    def default_decl(self, declaration_only = False):
        return None

    def tags(self):
        return ''.join([f.tags() for f in self.fields])

    def pb_field_t(self, prev_field_name):
        result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields])
        return result

    def get_last_field_name(self):
        if self.anonymous:
            return self.fields[-1].name
        else:
            return self.name + '.' + self.fields[-1].name

    def largest_field_value(self):
        largest = FieldMaxSize()
        for f in self.fields:
            largest.extend(f.largest_field_value())
        return largest

    def encoded_size(self, dependencies):
        '''Returns the size of the largest oneof field.'''
        largest = EncodedSize(0)
        for f in self.fields:
            size = EncodedSize(f.encoded_size(dependencies))
            if size.value is None:
                return None
            elif size.symbols:
                return None # Cannot resolve maximum of symbols
            elif size.value > largest.value:
                largest = size

        return largest

# ---------------------------------------------------------------------------
#                   Generation of messages (structures)
# ---------------------------------------------------------------------------


class Message:
    def __init__(self, names, desc, message_options):
        self.name = names
        self.fields = []
        self.oneofs = {}
        no_unions = []

        if message_options.msgid:
            self.msgid = message_options.msgid

        if hasattr(desc, 'oneof_decl'):
            for i, f in enumerate(desc.oneof_decl):
                oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name)
                if oneof_options.no_unions:
                    no_unions.append(i) # No union, but add fields normally
                elif oneof_options.type == nanopb_pb2.FT_IGNORE:
                    pass # No union and skip fields also
                else:
                    oneof = OneOf(self.name, f)
                    if oneof_options.anonymous_oneof:
                        oneof.anonymous = True
                    self.oneofs[i] = oneof
                    self.fields.append(oneof)

        for f in desc.field:
            field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
            if field_options.type == nanopb_pb2.FT_IGNORE:
                continue

            field = Field(self.name, f, field_options)
            if (hasattr(f, 'oneof_index') and
                f.HasField('oneof_index') and
                f.oneof_index not in no_unions):
                if f.oneof_index in self.oneofs:
                    self.oneofs[f.oneof_index].add_field(field)
            else:
                self.fields.append(field)

        if len(desc.extension_range) > 0:
            field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
            range_start = min([r.start for r in desc.extension_range])
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.fields.append(ExtensionRange(self.name, range_start, field_options))

        self.packed = message_options.packed_struct
        self.ordered_fields = self.fields[:]
        self.ordered_fields.sort()

    def get_dependencies(self):
        '''Get list of type names that this structure refers to.'''
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def __str__(self):
        result = 'typedef struct _%s {\n' % self.name

        if not self.ordered_fields:
            # Empty structs are not allowed in C standard.
            # Therefore add a dummy field if an empty message occurs.
            result += '    char dummy_field;'

        result += '\n'.join([str(f) for f in self.ordered_fields])
        result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name
        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.name

        if self.packed:
            result = 'PB_PACKED_STRUCT_START\n' + result
            result += '\nPB_PACKED_STRUCT_END'

        return result

    def types(self):
        return ''.join([f.types() for f in self.fields])

    def get_initializer(self, null_init):
        if not self.ordered_fields:
            return '{0}'

        parts = []
        for field in self.ordered_fields:
            parts.append(field.get_initializer(null_init))
        return '{' + ', '.join(parts) + '}'

    def default_decl(self, declaration_only = False):
        result = ""
        for field in self.fields:
            default = field.default_decl(declaration_only)
            if default is not None:
                result += default + '\n'
        return result

    def count_required_fields(self):
        '''Returns number of required fields inside this message'''
        count = 0
        for f in self.fields:
            if not isinstance(f, OneOf):
                if f.rules == 'REQUIRED':
                    count += 1
        return count

    def count_all_fields(self):
        count = 0
        for f in self.fields:
            if isinstance(f, OneOf):
                count += len(f.fields)
            else:
                count += 1
        return count

    def fields_declaration(self):
        result = 'extern const pb_field_t %s_fields[%d];' % (self.name, self.count_all_fields() + 1)
        return result

    def fields_definition(self):
        result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, self.count_all_fields() + 1)

        prev = None
        for field in self.ordered_fields:
            result += field.pb_field_t(prev)
            result += ',\n'
            prev = field.get_last_field_name()

        result += '    PB_LAST_FIELD\n};'
        return result

    def encoded_size(self, dependencies):
        '''Return the maximum size that this message can take when encoded.
        If the size cannot be determined, returns None.
        '''
        size = EncodedSize(0)
        for field in self.fields:
            fsize = field.encoded_size(dependencies)
            if fsize is None:
                return None
            size += fsize

        return size


# ---------------------------------------------------------------------------
#                    Processing of entire .proto files
# ---------------------------------------------------------------------------

def iterate_messages(desc, names = Names()):
    '''Recursively find all messages. For each, yield name, DescriptorProto.'''
    if hasattr(desc, 'message_type'):
        submsgs = desc.message_type
    else:
        submsgs = desc.nested_type

    for submsg in submsgs:
        sub_names = names + submsg.name
        yield sub_names, submsg

        for x in iterate_messages(submsg, sub_names):
            yield x

def iterate_extensions(desc, names = Names()):
    '''Recursively find all extensions.
    For each, yield name, FieldDescriptorProto.
    '''
    for extension in desc.extension:
        yield names, extension

    for subname, subdesc in iterate_messages(desc, names):
        for extension in subdesc.extension:
            yield subname, extension

def toposort2(data):
    '''Topological sort.
    From http://code.activestate.com/recipes/577413-topological-sort/
    This function is under the MIT license.
    '''
    for k, v in list(data.items()):
        v.discard(k) # Ignore self dependencies
    extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
    data.update(dict([(item, set()) for item in extra_items_in_deps]))
    while True:
        ordered = set(item for item,dep in list(data.items()) if not dep)
        if not ordered:
            break
        for item in sorted(ordered):
            yield item
        data = dict([(item, (dep - ordered)) for item,dep in list(data.items())
                if item not in ordered])
    assert not data, "A cyclic dependency exists amongst %r" % data

def sort_dependencies(messages):
    '''Sort a list of Messages based on dependencies.'''
    dependencies = {}
    message_by_name = {}
    for message in messages:
        dependencies[str(message.name)] = set(message.get_dependencies())
        message_by_name[str(message.name)] = message

    for msgname in toposort2(dependencies):
        if msgname in message_by_name:
            yield message_by_name[msgname]

def make_identifier(headername):
    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
    result = ""
    for c in headername.upper():
        if c.isalnum():
            result += c
        else:
            result += '_'
    return result

class ProtoFile:
    def __init__(self, fdesc, file_options):
        '''Takes a FileDescriptorProto and parses it.'''
        self.fdesc = fdesc
        self.file_options = file_options
        self.dependencies = {}
        self.parse()

        # Some of types used in this file probably come from the file itself.
        # Thus it has implicit dependency on itself.
        self.add_dependency(self)

    def parse(self):
        self.enums = []
        self.messages = []
        self.extensions = []

        if self.fdesc.package:
            base_name = Names(self.fdesc.package.split('.'))
        else:
            base_name = Names()

        for enum in self.fdesc.enum_type:
            enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name)
            self.enums.append(Enum(base_name, enum, enum_options))

        for names, message in iterate_messages(self.fdesc, base_name):
            message_options = get_nanopb_suboptions(message, self.file_options, names)

            if message_options.skip_message:
                continue

            self.messages.append(Message(names, message, message_options))
            for enum in message.enum_type:
                enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
                self.enums.append(Enum(names, enum, enum_options))

        for names, extension in iterate_extensions(self.fdesc, base_name):
            field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name)
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.extensions.append(ExtensionField(names, extension, field_options))

    def add_dependency(self, other):
        for enum in other.enums:
            self.dependencies[str(enum.names)] = enum

        for msg in other.messages:
            self.dependencies[str(msg.name)] = msg

        # Fix field default values where enum short names are used.
        for enum in other.enums:
            if not enum.options.long_names:
                for message in self.messages:
                    for field in message.fields:
                        if field.default in enum.value_longnames:
                            idx = enum.value_longnames.index(field.default)
                            field.default = enum.values[idx][0]

        # Fix field data types where enums have negative values.
        for enum in other.enums:
            if not enum.has_negative():
                for message in self.messages:
                    for field in message.fields:
                        if field.pbtype == 'ENUM' and field.ctype == enum.names:
                            field.pbtype = 'UENUM'

    def generate_header(self, includes, headername, options):
        '''Generate content for a header file.
        Generates strings, which should be concatenated and stored to file.
        '''

        yield '/* Automatically generated nanopb header */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())

        if self.fdesc.package:
            symbol = make_identifier(self.fdesc.package + '_' + headername)
        else:
            symbol = make_identifier(headername)
        yield '#ifndef PB_%s_INCLUDED\n' % symbol
        yield '#define PB_%s_INCLUDED\n' % symbol
        try:
            yield options.libformat % ('pb.h')
        except TypeError:
            # no %s specified - use whatever was passed in as options.libformat
            yield options.libformat
        yield '\n'

        for incfile in includes:
            noext = os.path.splitext(incfile)[0]
            yield options.genformat % (noext + options.extension + '.h')
            yield '\n'

        yield '/* @@protoc_insertion_point(includes) */\n'

        yield '#if PB_PROTO_HEADER_VERSION != 30\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        yield '#ifdef __cplusplus\n'
        yield 'extern "C" {\n'
        yield '#endif\n\n'

        if self.enums:
            yield '/* Enum definitions */\n'
            for enum in self.enums:
                yield str(enum) + '\n\n'

        if self.messages:
            yield '/* Struct definitions */\n'
            for msg in sort_dependencies(self.messages):
                yield msg.types()
                yield str(msg) + '\n\n'

        if self.extensions:
            yield '/* Extensions */\n'
            for extension in self.extensions:
                yield extension.extension_decl()
            yield '\n'

        if self.messages:
            yield '/* Default values for struct fields */\n'
            for msg in self.messages:
                yield msg.default_decl(True)
            yield '\n'

        yield '/* Initializer values for message structs */\n'
        for msg in self.messages:
            identifier = '%s_init_default' % msg.name
            yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
        for msg in self.messages:
            identifier = '%s_init_zero' % msg.name
            yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
        yield '\n'

        yield '/* Field tags (for use in manual encoding/decoding) */\n'
        for msg in sort_dependencies(self.messages):
            for field in msg.fields:
                yield field.tags()
        for extension in self.extensions:
            yield extension.tags()
        yield '\n'

        yield '/* Struct field encoding specification for nanopb */\n'
        for msg in self.messages:
            yield msg.fields_declaration() + '\n'
        yield '\n'

        yield '/* Maximum encoded size of messages (where known) */\n'
        for msg in self.messages:
            msize = msg.encoded_size(self.dependencies)
            identifier = '%s_size' % msg.name
            if msize is not None:
                yield '#define %-40s %s\n' % (identifier, msize)
            else:
                yield '/* %s depends on runtime parameters */\n' % identifier
        yield '\n'

        yield '/* Message IDs (where set with "msgid" option) */\n'

        yield '#ifdef PB_MSGID\n'
        for msg in self.messages:
            if hasattr(msg,'msgid'):
                yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
        yield '\n'

        symbol = make_identifier(headername.split('.')[0])
        yield '#define %s_MESSAGES \\\n' % symbol

        for msg in self.messages:
            m = "-1"
            msize = msg.encoded_size(self.dependencies)
            if msize is not None:
                m = msize
            if hasattr(msg,'msgid'):
                yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
        yield '\n'

        for msg in self.messages:
            if hasattr(msg,'msgid'):
                yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
        yield '\n'

        yield '#endif\n\n'

        yield '#ifdef __cplusplus\n'
        yield '} /* extern "C" */\n'
        yield '#endif\n'

        # End of header
        yield '/* @@protoc_insertion_point(eof) */\n'
        yield '\n#endif\n'

    def generate_source(self, headername, options):
        '''Generate content for a source file.'''

        yield '/* Automatically generated nanopb constant definitions */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
        yield options.genformat % (headername)
        yield '\n'
        yield '/* @@protoc_insertion_point(includes) */\n'

        yield '#if PB_PROTO_HEADER_VERSION != 30\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        for msg in self.messages:
            yield msg.default_decl(False)

        yield '\n\n'

        for msg in self.messages:
            yield msg.fields_definition() + '\n\n'

        for ext in self.extensions:
            yield ext.extension_def() + '\n'

        # Add checks for numeric limits
        if self.messages:
            largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
            largest_count = largest_msg.count_required_fields()
            if largest_count > 64:
                yield '\n/* Check that missing required fields will be properly detected */\n'
                yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
                yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
                yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
                yield '#endif\n'

        max_field = FieldMaxSize()
        checks_msgnames = []
        for msg in self.messages:
            checks_msgnames.append(msg.name)
            for field in msg.fields:
                max_field.extend(field.largest_field_value())

        worst = max_field.worst
        worst_field = max_field.worst_field
        checks = max_field.checks

        if worst > 255 or checks:
            yield '\n/* Check that field information fits in pb_field_t */\n'

            if worst > 65535 or checks:
                yield '#if !defined(PB_FIELD_32BIT)\n'
                if worst > 65535:
                    yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
                else:
                    assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
                    msgs = '_'.join(str(n) for n in checks_msgnames)
                    yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
                    yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
                    yield ' * \n'
                    yield ' * The reason you need to do this is that some of your messages contain tag\n'
                    yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
                    yield ' * field descriptors.\n'
                    yield ' */\n'
                    yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
                yield '#endif\n\n'

            if worst < 65536:
                yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
                if worst > 255:
                    yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
                else:
                    assertion = ' && '.join(str(c) + ' < 256' for c in checks)
                    msgs = '_'.join(str(n) for n in checks_msgnames)
                    yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
                    yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
                    yield ' * \n'
                    yield ' * The reason you need to do this is that some of your messages contain tag\n'
                    yield ' * numbers or field sizes that are larger than what can fit in the default\n'
                    yield ' * 8 bit descriptors.\n'
                    yield ' */\n'
                    yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
                yield '#endif\n\n'

        # Add check for sizeof(double)
        has_double = False
        for msg in self.messages:
            for field in msg.fields:
                if field.ctype == 'double':
                    has_double = True

        if has_double:
            yield '\n'
            yield '/* On some platforms (such as AVR), double is really float.\n'
            yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
            yield ' * To get rid of this error, remove any double fields from your .proto.\n'
            yield ' */\n'
            yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'

        yield '\n'
        yield '/* @@protoc_insertion_point(eof) */\n'

# ---------------------------------------------------------------------------
#                    Options parsing for the .proto files
# ---------------------------------------------------------------------------

from fnmatch import fnmatch

def read_options_file(infile):
    '''Parse a separate options file to list:
        [(namemask, options), ...]
    '''
    results = []
    data = infile.read()
    data = re.sub('/\*.*?\*/', '', data, flags = re.MULTILINE)
    data = re.sub('//.*?$', '', data, flags = re.MULTILINE)
    data = re.sub('#.*?$', '', data, flags = re.MULTILINE)
    for i, line in enumerate(data.split('\n')):
        line = line.strip()
        if not line:
            continue

        parts = line.split(None, 1)

        if len(parts) < 2:
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Option lines should have space between field name and options. " +
                             "Skipping line: '%s'\n" % line)
            continue

        opts = nanopb_pb2.NanoPBOptions()

        try:
            text_format.Merge(parts[1], opts)
        except Exception as e:
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Unparseable option line: '%s'. " % line +
                             "Error: %s\n" % str(e))
            continue
        results.append((parts[0], opts))

    return results

class Globals:
    '''Ugly global variables, should find a good way to pass these.'''
    verbose_options = False
    separate_options = []
    matched_namemasks = set()

def get_nanopb_suboptions(subdesc, options, name):
    '''Get copy of options, and merge information from subdesc.'''
    new_options = nanopb_pb2.NanoPBOptions()
    new_options.CopyFrom(options)

    # Handle options defined in a separate file
    dotname = '.'.join(name.parts)
    for namemask, options in Globals.separate_options:
        if fnmatch(dotname, namemask):
            Globals.matched_namemasks.add(namemask)
            new_options.MergeFrom(options)

    # Handle options defined in .proto
    if isinstance(subdesc.options, descriptor.FieldOptions):
        ext_type = nanopb_pb2.nanopb
    elif isinstance(subdesc.options, descriptor.FileOptions):
        ext_type = nanopb_pb2.nanopb_fileopt
    elif isinstance(subdesc.options, descriptor.MessageOptions):
        ext_type = nanopb_pb2.nanopb_msgopt
    elif isinstance(subdesc.options, descriptor.EnumOptions):
        ext_type = nanopb_pb2.nanopb_enumopt
    else:
        raise Exception("Unknown options type")

    if subdesc.options.HasExtension(ext_type):
        ext = subdesc.options.Extensions[ext_type]
        new_options.MergeFrom(ext)

    if Globals.verbose_options:
        sys.stderr.write("Options for " + dotname + ": ")
        sys.stderr.write(text_format.MessageToString(new_options) + "\n")

    return new_options


# ---------------------------------------------------------------------------
#                         Command line interface
# ---------------------------------------------------------------------------

import sys
import os.path
from optparse import OptionParser

optparser = OptionParser(
    usage = "Usage: nanopb_generator.py [options] file.pb ...",
    epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
             "Output will be written to file.pb.h and file.pb.c.")
optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
    help="Exclude file from generated #include list.")
optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
    help="Set extension to use instead of '.pb' for generated files. [default: %default]")
optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
    help="Set name of a separate generator options file.")
optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR",
    action="append", default = [],
    help="Search for .options files additionally in this path")
optparser.add_option("-D", "--output-dir", dest="output_dir",
                     metavar="OUTPUTDIR", default=None,
                     help="Output directory of .pb.h and .pb.c files")
optparser.add_option("-Q", "--generated-include-format", dest="genformat",
    metavar="FORMAT", default='#include "%s"\n',
    help="Set format string to use for including other .pb.h files. [default: %default]")
optparser.add_option("-L", "--library-include-format", dest="libformat",
    metavar="FORMAT", default='#include <%s>\n',
    help="Set format string to use for including the nanopb pb.h header. [default: %default]")
optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False,
    help="Don't add timestamp to .pb.h and .pb.c preambles")
optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
    help="Don't print anything except errors.")
optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
    help="Print more information.")
optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
    help="Set generator option (max_size, max_count etc.).")

def parse_file(filename, fdesc, options):
    '''Parse a single file. Returns a ProtoFile instance.'''
    toplevel_options = nanopb_pb2.NanoPBOptions()
    for s in options.settings:
        text_format.Merge(s, toplevel_options)

    if not fdesc:
        data = open(filename, 'rb').read()
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]

    # Check if there is a separate .options file
    had_abspath = False
    try:
        optfilename = options.options_file % os.path.splitext(filename)[0]
    except TypeError:
        # No %s specified, use the filename as-is
        optfilename = options.options_file
        had_abspath = True

    paths = ['.'] + options.options_path
    for p in paths:
        if os.path.isfile(os.path.join(p, optfilename)):
            optfilename = os.path.join(p, optfilename)
            if options.verbose:
                sys.stderr.write('Reading options from ' + optfilename + '\n')
            Globals.separate_options = read_options_file(open(optfilename, "rU"))
            break
    else:
        # If we are given a full filename and it does not exist, give an error.
        # However, don't give error when we automatically look for .options file
        # with the same name as .proto.
        if options.verbose or had_abspath:
            sys.stderr.write('Options file not found: ' + optfilename + '\n')
        Globals.separate_options = []

    Globals.matched_namemasks = set()

    # Parse the file
    file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
    f = ProtoFile(fdesc, file_options)
    f.optfilename = optfilename

    return f

def process_file(filename, fdesc, options, other_files = {}):
    '''Process a single file.
    filename: The full path to the .proto or .pb source file, as string.
    fdesc: The loaded FileDescriptorSet, or None to read from the input file.
    options: Command line options as they come from OptionsParser.

    Returns a dict:
        {'headername': Name of header file,
         'headerdata': Data for the .h header file,
         'sourcename': Name of the source code file,
         'sourcedata': Data for the .c source code file
        }
    '''
    f = parse_file(filename, fdesc, options)

    # Provide dependencies if available
    for dep in f.fdesc.dependency:
        if dep in other_files:
            f.add_dependency(other_files[dep])

    # Decide the file names
    noext = os.path.splitext(filename)[0]
    headername = noext + options.extension + '.h'
    sourcename = noext + options.extension + '.c'
    headerbasename = os.path.basename(headername)

    # List of .proto files that should not be included in the C header file
    # even if they are mentioned in the source .proto.
    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude
    includes = [d for d in f.fdesc.dependency if d not in excludes]

    headerdata = ''.join(f.generate_header(includes, headerbasename, options))
    sourcedata = ''.join(f.generate_source(headerbasename, options))

    # Check if there were any lines in .options that did not match a member
    unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
    if unmatched and not options.quiet:
        sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: "
                         + ', '.join(unmatched) + "\n")
        if not Globals.verbose_options:
            sys.stderr.write("Use protoc --nanopb-out=-v:. to see a list of the field names.\n")

    return {'headername': headername, 'headerdata': headerdata,
            'sourcename': sourcename, 'sourcedata': sourcedata}

def main_cli():
    '''Main function when invoked directly from the command line.'''

    options, filenames = optparser.parse_args()

    if not filenames:
        optparser.print_help()
        sys.exit(1)

    if options.quiet:
        options.verbose = False

    if options.output_dir and not os.path.exists(options.output_dir):
        optparser.print_help()
        sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir)
        sys.exit(1)

    Globals.verbose_options = options.verbose
    for filename in filenames:
        results = process_file(filename, None, options)

        base_dir = options.output_dir or ''
        to_write = [
            (os.path.join(base_dir, results['headername']), results['headerdata']),
            (os.path.join(base_dir, results['sourcename']), results['sourcedata']),
        ]

        if not options.quiet:
            paths = " and ".join([x[0] for x in to_write])
            sys.stderr.write("Writing to %s\n" % paths)

        for path, data in to_write:
            with open(path, 'w') as f:
                f.write(data)

def main_plugin():
    '''Main function when invoked as a protoc plugin.'''

    import io, sys
    if sys.platform == "win32":
        import os, msvcrt
        # Set stdin and stdout to binary mode
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    data = io.open(sys.stdin.fileno(), "rb").read()

    request = plugin_pb2.CodeGeneratorRequest.FromString(data)

    try:
        # Versions of Python prior to 2.7.3 do not support unicode
        # input to shlex.split(). Try to convert to str if possible.
        params = str(request.parameter)
    except UnicodeEncodeError:
        params = request.parameter

    import shlex
    args = shlex.split(params)
    options, dummy = optparser.parse_args(args)

    Globals.verbose_options = options.verbose

    response = plugin_pb2.CodeGeneratorResponse()

    # Google's protoc does not currently indicate the full path of proto files.
    # Instead always add the main file path to the search dirs, that works for
    # the common case.
    import os.path
    options.options_path.append(os.path.dirname(request.file_to_generate[0]))

    # Process any include files first, in order to have them
    # available as dependencies
    other_files = {}
    for fdesc in request.proto_file:
        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)

    for filename in request.file_to_generate:
        for fdesc in request.proto_file:
            if fdesc.name == filename:
                results = process_file(filename, fdesc, options, other_files)

                f = response.file.add()
                f.name = results['headername']
                f.content = results['headerdata']

                f = response.file.add()
                f.name = results['sourcename']
                f.content = results['sourcedata']

    io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())

if __name__ == '__main__':
    # Check if we are running as a plugin under protoc
    if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv:
        main_plugin()
    else:
        main_cli()
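
# Example usage (illustrative sketch based on the option parser epilog above;
# the generator/protoc-gen-nanopb plugin path is installation-specific):
#
#   Standalone:
#     protoc -omessage.pb message.proto
#     python nanopb_generator.py message.pb    # writes message.pb.h and message.pb.c
#
#   As a protoc plugin (argv[0] must contain 'protoc-gen-', which the
#   __main__ check above keys on):
#     protoc --plugin=protoc-gen-nanopb=generator/protoc-gen-nanopb \
#            --nanopb_out=. message.proto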