# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The memoryview of the encoded message.
  pos:        The current position in the buffer.
  end:        The position in the buffer where the current message ends.  May
              be less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor.  Most decoders treat it
                 as an opaque key, but EnumDecoder reads key.enum_type and
                 StringDecoder reads key.full_name.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = 'kenton@google.com (Kenton Varda)'

import math
import struct

from google.protobuf import message
from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError


def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).

  The returned decoder has two modes:
    * DecodeVarint(buffer, pos): reads from an indexable buffer starting at pos
      and returns a (value, new_pos) pair.
    * DecodeVarint(stream): with pos omitted, reads one byte at a time via
      stream.read(1) (e.g. an io.BytesIO) and returns just the value, or None
      if the stream is already exhausted before the first byte.  A stream that
      ends in the middle of a varint raises ValueError.

  Raises _DecodeError if the varint is longer than 10 bytes.
  """

  def DecodeVarint(buffer, pos: int=None):
    result = 0
    shift = 0
    while 1:
      if pos is None:
        # Read from BytesIO
        try:
          b = buffer.read(1)[0]
        except IndexError as e:
          if shift == 0:
            # End of BytesIO.
            return None
          else:
            raise ValueError('Fail to read varint %s' % str(e)) from e
      else:
        b = buffer[pos]
        pos += 1
      result |= ((b & 0x7f) << shift)
      # A clear high bit marks the last byte of the varint.
      if not (b & 0x80):
        result &= mask
        result = result_type(result)
        return result if pos is None else (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint


def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  The raw varint is truncated to `bits` bits and then reinterpreted as a
  two's-complement integer of that width.  The returned decoder always takes
  (buffer, pos) and returns a (value, new_pos) pair.
  """

  signbit = 1 << (bits - 1)
  mask = (1 << bits) - 1

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = buffer[pos]
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        result &= mask
        # Two's-complement reinterpretation of the low `bits` bits.
        result = (result ^ signbit) - signbit
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint

# All 32-bit and 64-bit values are represented as int.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)


def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  We return the raw bytes of the tag rather than decoding them.  The raw
  bytes can then be used to look up the proper decoder.  This effectively allows
  us to trade some work that would be done in pure-python (decoding a varint)
  for work that is done in C (searching for a byte string in a hash table).
  In a low-level language it would be much cheaper to decode the varint and
  use that, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  start = pos
  while buffer[pos] & 0x80:
    pos += 1
  pos += 1

  tag_bytes = buffer[start:pos].tobytes()
  return tag_bytes, pos


# --------------------------------------------------------------------


def _IsDefaultScalarValue(value):
  """Return True if a decoded scalar equals its type's implicit default.

  Used for fields decoded with clear_if_default=True.  Plain truthiness is
  almost right, but -0.0 is falsy while being a different value from the
  +0.0 default, so negative zero is treated as a non-default value and kept.
  """
  if isinstance(value, float) and math.copysign(1.0, value) < 0:
    return False
  return not value


def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A packed field is a length prefix followed by back-to-back values.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        if pos > endpoint:
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (new_value, pos) = decode_value(buffer, pos)
        if pos > end:
          raise _DecodeError('Truncated message.')
        # Implicit-presence fields drop default values instead of storing
        # them; -0.0 is not the default (see _IsDefaultScalarValue).
        if clear_if_default and _IsDefaultScalarValue(new_value):
          field_dict.pop(key, None)
        else:
          field_dict[key] = new_value
        return pos
      return DecodeField

  return SpecificDecoder


def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like SimpleDecoder but additionally invokes modify_value on every value
  before storing it.  Usually modify_value is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    (result, new_pos) = decode_value(buffer, pos)
    return (modify_value(result), new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  value_size = struct.calcsize(format)
  local_unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    result = local_unpack(format, buffer[pos:new_pos])[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (math.nan, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (-math.inf, new_pos)
      return (math.inf, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)


def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (math.nan, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)


def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values found in key.enum_type.values_by_number are stored in field_dict;
  any other value is preserved as raw bytes in message._unknown_fields under
  this field's VARINT tag.
  """
  enum_type = key.enum_type
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          tag_bytes = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)

          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
          # pylint: enable=protected-access
      if pos > endpoint:
        # Undo whichever container received the truncated last element.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      return pos
    return DecodeField


# --------------------------------------------------------------------


Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
507Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 508Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 509SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 510SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 511FloatDecoder = _FloatDecoder() 512DoubleDecoder = _DoubleDecoder() 513 514BoolDecoder = _ModifiedDecoder( 515 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 516 517 518def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 519 clear_if_default=False): 520 """Returns a decoder for a string field.""" 521 522 local_DecodeVarint = _DecodeVarint 523 524 def _ConvertToUnicode(memview): 525 """Convert byte to unicode.""" 526 byte_str = memview.tobytes() 527 try: 528 value = str(byte_str, 'utf-8') 529 except UnicodeDecodeError as e: 530 # add more information to the error message and re-raise it. 531 e.reason = '%s in field: %s' % (e, key.full_name) 532 raise 533 534 return value 535 536 assert not is_packed 537 if is_repeated: 538 tag_bytes = encoder.TagBytes(field_number, 539 wire_format.WIRETYPE_LENGTH_DELIMITED) 540 tag_len = len(tag_bytes) 541 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 542 value = field_dict.get(key) 543 if value is None: 544 value = field_dict.setdefault(key, new_default(message)) 545 while 1: 546 (size, pos) = local_DecodeVarint(buffer, pos) 547 new_pos = pos + size 548 if new_pos > end: 549 raise _DecodeError('Truncated string.') 550 value.append(_ConvertToUnicode(buffer[pos:new_pos])) 551 # Predict that the next tag is another copy of the same repeated field. 552 pos = new_pos + tag_len 553 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 554 # Prediction failed. Return. 
555 return new_pos 556 return DecodeRepeatedField 557 else: 558 def DecodeField(buffer, pos, end, message, field_dict): 559 (size, pos) = local_DecodeVarint(buffer, pos) 560 new_pos = pos + size 561 if new_pos > end: 562 raise _DecodeError('Truncated string.') 563 if clear_if_default and not size: 564 field_dict.pop(key, None) 565 else: 566 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 567 return new_pos 568 return DecodeField 569 570 571def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, 572 clear_if_default=False): 573 """Returns a decoder for a bytes field.""" 574 575 local_DecodeVarint = _DecodeVarint 576 577 assert not is_packed 578 if is_repeated: 579 tag_bytes = encoder.TagBytes(field_number, 580 wire_format.WIRETYPE_LENGTH_DELIMITED) 581 tag_len = len(tag_bytes) 582 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 583 value = field_dict.get(key) 584 if value is None: 585 value = field_dict.setdefault(key, new_default(message)) 586 while 1: 587 (size, pos) = local_DecodeVarint(buffer, pos) 588 new_pos = pos + size 589 if new_pos > end: 590 raise _DecodeError('Truncated string.') 591 value.append(buffer[pos:new_pos].tobytes()) 592 # Predict that the next tag is another copy of the same repeated field. 593 pos = new_pos + tag_len 594 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 595 # Prediction failed. Return. 
596 return new_pos 597 return DecodeRepeatedField 598 else: 599 def DecodeField(buffer, pos, end, message, field_dict): 600 (size, pos) = local_DecodeVarint(buffer, pos) 601 new_pos = pos + size 602 if new_pos > end: 603 raise _DecodeError('Truncated string.') 604 if clear_if_default and not size: 605 field_dict.pop(key, None) 606 else: 607 field_dict[key] = buffer[pos:new_pos].tobytes() 608 return new_pos 609 return DecodeField 610 611 612def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 613 """Returns a decoder for a group field.""" 614 615 end_tag_bytes = encoder.TagBytes(field_number, 616 wire_format.WIRETYPE_END_GROUP) 617 end_tag_len = len(end_tag_bytes) 618 619 assert not is_packed 620 if is_repeated: 621 tag_bytes = encoder.TagBytes(field_number, 622 wire_format.WIRETYPE_START_GROUP) 623 tag_len = len(tag_bytes) 624 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 625 value = field_dict.get(key) 626 if value is None: 627 value = field_dict.setdefault(key, new_default(message)) 628 while 1: 629 value = field_dict.get(key) 630 if value is None: 631 value = field_dict.setdefault(key, new_default(message)) 632 # Read sub-message. 633 current_depth += 1 634 if current_depth > _recursion_limit: 635 raise _DecodeError( 636 'Error parsing message: too many levels of nesting.' 637 ) 638 pos = value.add()._InternalParse(buffer, pos, end) 639 current_depth -= 1 640 # Read end tag. 641 new_pos = pos+end_tag_len 642 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 643 raise _DecodeError('Missing group end tag.') 644 # Predict that the next tag is another copy of the same repeated field. 645 pos = new_pos + tag_len 646 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 647 # Prediction failed. Return. 
648 return new_pos 649 return DecodeRepeatedField 650 else: 651 def DecodeField(buffer, pos, end, message, field_dict): 652 value = field_dict.get(key) 653 if value is None: 654 value = field_dict.setdefault(key, new_default(message)) 655 # Read sub-message. 656 current_depth += 1 657 if current_depth > _recursion_limit: 658 raise _DecodeError('Error parsing message: too many levels of nesting.') 659 pos = value._InternalParse(buffer, pos, end) 660 current_depth -= 1 661 # Read end tag. 662 new_pos = pos+end_tag_len 663 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 664 raise _DecodeError('Missing group end tag.') 665 return new_pos 666 return DecodeField 667 668 669def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): 670 """Returns a decoder for a message field.""" 671 672 local_DecodeVarint = _DecodeVarint 673 674 assert not is_packed 675 if is_repeated: 676 tag_bytes = encoder.TagBytes(field_number, 677 wire_format.WIRETYPE_LENGTH_DELIMITED) 678 tag_len = len(tag_bytes) 679 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 680 value = field_dict.get(key) 681 if value is None: 682 value = field_dict.setdefault(key, new_default(message)) 683 while 1: 684 # Read length. 685 (size, pos) = local_DecodeVarint(buffer, pos) 686 new_pos = pos + size 687 if new_pos > end: 688 raise _DecodeError('Truncated message.') 689 # Read sub-message. 690 current_depth += 1 691 if current_depth > _recursion_limit: 692 raise _DecodeError( 693 'Error parsing message: too many levels of nesting.' 694 ) 695 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 696 # The only reason _InternalParse would return early is if it 697 # encountered an end-group tag. 698 raise _DecodeError('Unexpected end-group tag.') 699 current_depth -= 1 700 # Predict that the next tag is another copy of the same repeated field. 701 pos = new_pos + tag_len 702 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 703 # Prediction failed. Return. 
704 return new_pos 705 return DecodeRepeatedField 706 else: 707 def DecodeField(buffer, pos, end, message, field_dict): 708 value = field_dict.get(key) 709 if value is None: 710 value = field_dict.setdefault(key, new_default(message)) 711 # Read length. 712 (size, pos) = local_DecodeVarint(buffer, pos) 713 new_pos = pos + size 714 if new_pos > end: 715 raise _DecodeError('Truncated message.') 716 # Read sub-message. 717 current_depth += 1 718 if current_depth > _recursion_limit: 719 raise _DecodeError('Error parsing message: too many levels of nesting.') 720 if value._InternalParse(buffer, pos, new_pos) != new_pos: 721 # The only reason _InternalParse would return early is if it encountered 722 # an end-group tag. 723 raise _DecodeError('Unexpected end-group tag.') 724 current_depth -= 1 725 return new_pos 726 return DecodeField 727 728 729# -------------------------------------------------------------------- 730 731MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) 732 733def MessageSetItemDecoder(descriptor): 734 """Returns a decoder for a MessageSet item. 735 736 The parameter is the message Descriptor. 737 738 The message set message looks like this: 739 message MessageSet { 740 repeated group Item = 1 { 741 required int32 type_id = 2; 742 required string message = 3; 743 } 744 } 745 """ 746 747 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 748 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 749 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 750 751 local_ReadTag = ReadTag 752 local_DecodeVarint = _DecodeVarint 753 local_SkipField = SkipField 754 755 def DecodeItem(buffer, pos, end, message, field_dict): 756 """Decode serialized message set to its value and new position. 757 758 Args: 759 buffer: memoryview of the serialized bytes. 760 pos: int, position in the memory view to start at. 
761 end: int, end position of serialized data 762 message: Message object to store unknown fields in 763 field_dict: Map[Descriptor, Any] to store decoded values in. 764 765 Returns: 766 int, new position in serialized data. 767 """ 768 message_set_item_start = pos 769 type_id = -1 770 message_start = -1 771 message_end = -1 772 773 # Technically, type_id and message can appear in any order, so we need 774 # a little loop here. 775 while 1: 776 (tag_bytes, pos) = local_ReadTag(buffer, pos) 777 if tag_bytes == type_id_tag_bytes: 778 (type_id, pos) = local_DecodeVarint(buffer, pos) 779 elif tag_bytes == message_tag_bytes: 780 (size, message_start) = local_DecodeVarint(buffer, pos) 781 pos = message_end = message_start + size 782 elif tag_bytes == item_end_tag_bytes: 783 break 784 else: 785 pos = SkipField(buffer, pos, end, tag_bytes) 786 if pos == -1: 787 raise _DecodeError('Missing group end tag.') 788 789 if pos > end: 790 raise _DecodeError('Truncated message.') 791 792 if type_id == -1: 793 raise _DecodeError('MessageSet item missing type_id.') 794 if message_start == -1: 795 raise _DecodeError('MessageSet item missing message.') 796 797 extension = message.Extensions._FindExtensionByNumber(type_id) 798 # pylint: disable=protected-access 799 if extension is not None: 800 value = field_dict.get(extension) 801 if value is None: 802 message_type = extension.message_type 803 if not hasattr(message_type, '_concrete_class'): 804 message_factory.GetMessageClass(message_type) 805 value = field_dict.setdefault( 806 extension, message_type._concrete_class()) 807 if value._InternalParse(buffer, message_start,message_end) != message_end: 808 # The only reason _InternalParse would return early is if it encountered 809 # an end-group tag. 
810 raise _DecodeError('Unexpected end-group tag.') 811 else: 812 if not message._unknown_fields: 813 message._unknown_fields = [] 814 message._unknown_fields.append( 815 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) 816 # pylint: enable=protected-access 817 818 return pos 819 820 return DecodeItem 821 822 823def UnknownMessageSetItemDecoder(): 824 """Returns a decoder for a Unknown MessageSet item.""" 825 826 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 827 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 828 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 829 830 def DecodeUnknownItem(buffer): 831 pos = 0 832 end = len(buffer) 833 message_start = -1 834 message_end = -1 835 while 1: 836 (tag_bytes, pos) = ReadTag(buffer, pos) 837 if tag_bytes == type_id_tag_bytes: 838 (type_id, pos) = _DecodeVarint(buffer, pos) 839 elif tag_bytes == message_tag_bytes: 840 (size, message_start) = _DecodeVarint(buffer, pos) 841 pos = message_end = message_start + size 842 elif tag_bytes == item_end_tag_bytes: 843 break 844 else: 845 pos = SkipField(buffer, pos, end, tag_bytes) 846 if pos == -1: 847 raise _DecodeError('Missing group end tag.') 848 849 if pos > end: 850 raise _DecodeError('Truncated message.') 851 852 if type_id == -1: 853 raise _DecodeError('MessageSet item missing type_id.') 854 if message_start == -1: 855 raise _DecodeError('MessageSet item missing message.') 856 857 return (type_id, buffer[message_start:message_end].tobytes()) 858 859 return DecodeUnknownItem 860 861# -------------------------------------------------------------------- 862 863def MapDecoder(field_descriptor, new_default, is_message_map): 864 """Returns a decoder for a map field.""" 865 866 key = field_descriptor 867 tag_bytes = encoder.TagBytes(field_descriptor.number, 868 wire_format.WIRETYPE_LENGTH_DELIMITED) 869 tag_len = len(tag_bytes) 870 local_DecodeVarint = _DecodeVarint 871 # Can't read 
_concrete_class yet; might not be initialized. 872 message_type = field_descriptor.message_type 873 874 def DecodeMap(buffer, pos, end, message, field_dict): 875 submsg = message_type._concrete_class() 876 value = field_dict.get(key) 877 if value is None: 878 value = field_dict.setdefault(key, new_default(message)) 879 while 1: 880 # Read length. 881 (size, pos) = local_DecodeVarint(buffer, pos) 882 new_pos = pos + size 883 if new_pos > end: 884 raise _DecodeError('Truncated message.') 885 # Read sub-message. 886 submsg.Clear() 887 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 888 # The only reason _InternalParse would return early is if it 889 # encountered an end-group tag. 890 raise _DecodeError('Unexpected end-group tag.') 891 892 if is_message_map: 893 value[submsg.key].CopyFrom(submsg.value) 894 else: 895 value[submsg.key] = submsg.value 896 897 # Predict that the next tag is another copy of the same repeated field. 898 pos = new_pos + tag_len 899 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 900 # Prediction failed. Return. 901 return new_pos 902 903 return DecodeMap 904 905# -------------------------------------------------------------------- 906# Optimization is not as heavy here because calls to SkipField() are rare, 907# except for handling end-group tags. 908 909def _SkipVarint(buffer, pos, end): 910 """Skip a varint value. Returns the new position.""" 911 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 912 # With this code, ord(b'') raises TypeError. Both are handled in 913 # python_message.py to generate a 'Truncated message' error. 914 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 915 pos += 1 916 pos += 1 917 if pos > end: 918 raise _DecodeError('Truncated message.') 919 return pos 920 921def _SkipFixed64(buffer, pos, end): 922 """Skip a fixed64 value. 
Returns the new position.""" 923 924 pos += 8 925 if pos > end: 926 raise _DecodeError('Truncated message.') 927 return pos 928 929 930def _DecodeFixed64(buffer, pos): 931 """Decode a fixed64.""" 932 new_pos = pos + 8 933 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 934 935 936def _SkipLengthDelimited(buffer, pos, end): 937 """Skip a length-delimited value. Returns the new position.""" 938 939 (size, pos) = _DecodeVarint(buffer, pos) 940 pos += size 941 if pos > end: 942 raise _DecodeError('Truncated message.') 943 return pos 944 945 946def _SkipGroup(buffer, pos, end): 947 """Skip sub-group. Returns the new position.""" 948 949 while 1: 950 (tag_bytes, pos) = ReadTag(buffer, pos) 951 new_pos = SkipField(buffer, pos, end, tag_bytes) 952 if new_pos == -1: 953 return pos 954 pos = new_pos 955 956 957def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): 958 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 959 960 unknown_field_set = containers.UnknownFieldSet() 961 while end_pos is None or pos < end_pos: 962 (tag_bytes, pos) = ReadTag(buffer, pos) 963 (tag, _) = _DecodeVarint(tag_bytes, 0) 964 field_number, wire_type = wire_format.UnpackTag(tag) 965 if wire_type == wire_format.WIRETYPE_END_GROUP: 966 break 967 (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) 968 # pylint: disable=protected-access 969 unknown_field_set._add(field_number, wire_type, data) 970 971 return (unknown_field_set, pos) 972 973 974def _DecodeUnknownField(buffer, pos, wire_type): 975 """Decode a unknown field. 
Returns the UnknownField and new position.""" 976 977 if wire_type == wire_format.WIRETYPE_VARINT: 978 (data, pos) = _DecodeVarint(buffer, pos) 979 elif wire_type == wire_format.WIRETYPE_FIXED64: 980 (data, pos) = _DecodeFixed64(buffer, pos) 981 elif wire_type == wire_format.WIRETYPE_FIXED32: 982 (data, pos) = _DecodeFixed32(buffer, pos) 983 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 984 (size, pos) = _DecodeVarint(buffer, pos) 985 data = buffer[pos:pos+size].tobytes() 986 pos += size 987 elif wire_type == wire_format.WIRETYPE_START_GROUP: 988 current_depth += 1 989 if current_depth >= _recursion_limit: 990 raise _DecodeError('Error parsing message: too many levels of nesting.') 991 (data, pos) = _DecodeUnknownFieldSet(buffer, pos) 992 current_depth -= 1 993 elif wire_type == wire_format.WIRETYPE_END_GROUP: 994 return (0, -1) 995 else: 996 raise _DecodeError('Wrong wire type in tag.') 997 998 return (data, pos) 999 1000 1001def _EndGroup(buffer, pos, end): 1002 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 1003 1004 return -1 1005 1006 1007def _SkipFixed32(buffer, pos, end): 1008 """Skip a fixed32 value. Returns the new position.""" 1009 1010 pos += 4 1011 if pos > end: 1012 raise _DecodeError('Truncated message.') 1013 return pos 1014 1015 1016def _DecodeFixed32(buffer, pos): 1017 """Decode a fixed32.""" 1018 1019 new_pos = pos + 4 1020 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 1021 1022DEFAULT_RECURSION_LIMIT = 100 1023_recursion_limit = DEFAULT_RECURSION_LIMIT 1024 1025 1026def SetRecursionLimit(new_limit): 1027 global _recursion_limit 1028 _recursion_limit = new_limit 1029 1030 1031def _RaiseInvalidWireType(buffer, pos, end): 1032 """Skip function for unknown wire types. 
Raises an exception.""" 1033 1034 raise _DecodeError('Tag had invalid wire type.') 1035 1036def _FieldSkipper(): 1037 """Constructs the SkipField function.""" 1038 1039 WIRETYPE_TO_SKIPPER = [ 1040 _SkipVarint, 1041 _SkipFixed64, 1042 _SkipLengthDelimited, 1043 _SkipGroup, 1044 _EndGroup, 1045 _SkipFixed32, 1046 _RaiseInvalidWireType, 1047 _RaiseInvalidWireType, 1048 ] 1049 1050 wiretype_mask = wire_format.TAG_TYPE_MASK 1051 1052 def SkipField(buffer, pos, end, tag_bytes): 1053 """Skips a field with the specified tag. 1054 1055 |pos| should point to the byte immediately after the tag. 1056 1057 Returns: 1058 The new position (after the tag value), or -1 if the tag is an end-group 1059 tag (in which case the calling loop should break). 1060 """ 1061 1062 # The wire type is always in the first byte since varints are little-endian. 1063 wire_type = ord(tag_bytes[0:1]) & wiretype_mask 1064 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 1065 1066 return SkipField 1067 1068SkipField = _FieldSkipper() 1069