1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31"""Code for decoding protocol buffer primitives. 32 33This code is very similar to encoder.py -- read the docs for that module first. 34 35A "decoder" is a function with the signature: 36 Decode(buffer, pos, end, message, field_dict) 37The arguments are: 38 buffer: The string containing the encoded message. 39 pos: The current position in the string. 40 end: The position in the string where the current message ends. May be 41 less than len(buffer) if we're reading a sub-message. 42 message: The message object into which we're parsing. 43 field_dict: message._fields (avoids a hashtable lookup). 44The decoder reads the field and stores it into field_dict, returning the new 45buffer position. A decoder for a repeated field may proactively decode all of 46the elements of that field, if they appear consecutively. 47 48Note that decoders may throw any of the following: 49 IndexError: Indicates a truncated message. 50 struct.error: Unpacking of a fixed-width field failed. 51 message.DecodeError: Other errors. 52 53Decoders are expected to raise an exception if they are called with pos > end. 54This allows callers to be lax about bounds checking: it's fineto read past 55"end" as long as you are sure that someone else will notice and throw an 56exception later on. 57 58Something up the call stack is expected to catch IndexError and struct.error 59and convert them to message.DecodeError. 60 61Decoders are constructed using decoder constructors with the signature: 62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 63The arguments are: 64 field_number: The field number of the field we want to decode. 65 is_repeated: Is the field a repeated field? (bool) 66 is_packed: Is the field a packed field? (bool) 67 key: The key to use when looking up the field within field_dict. 68 (This is actually the FieldDescriptor but nothing in this 69 file should depend on that.) 70 new_default: A function which takes a message object as a parameter and 71 returns a new instance of the default value for this field. 72 (This is called for repeated fields and sub-messages, when an 73 instance does not already exist.) 74 75As with encoders, we define a decoder constructor for every type of field. 76Then, for every field of every message class we construct an actual decoder. 77That decoder goes into a dict indexed by tag, so when we decode a message 78we repeatedly read a tag, look up the corresponding decoder, and invoke it. 79""" 80 81__author__ = 'kenton@google.com (Kenton Varda)' 82 83import struct 84import sys 85import six 86 87_UCS2_MAXUNICODE = 65535 88if six.PY3: 89 long = int 90else: 91 import re # pylint: disable=g-import-not-at-top 92 _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) 93 94from google.protobuf.internal import containers 95from google.protobuf.internal import encoder 96from google.protobuf.internal import wire_format 97from google.protobuf import message 98 99 100# This will overflow and thus become IEEE-754 "infinity". We would use 101# "float('inf')" but it doesn't work on Windows pre-Python-2.6. 102_POS_INF = 1e10000 103_NEG_INF = -_POS_INF 104_NAN = _POS_INF * 0 105 106 107# This is not for optimization, but rather to avoid conflicts with local 108# variables named "message". 109_DecodeError = message.DecodeError 110 111 112def _VarintDecoder(mask, result_type): 113 """Return an encoder for a basic varint value (does not include tag). 114 115 Decoded values will be bitwise-anded with the given mask before being 116 returned, e.g. to limit them to 32 bits. The returned decoder does not 117 take the usual "end" parameter -- the caller is expected to do bounds checking 118 after the fact (often the caller can defer such checking until later). The 119 decoder returns a (value, new_pos) pair. 120 """ 121 122 def DecodeVarint(buffer, pos): 123 result = 0 124 shift = 0 125 while 1: 126 b = six.indexbytes(buffer, pos) 127 result |= ((b & 0x7f) << shift) 128 pos += 1 129 if not (b & 0x80): 130 result &= mask 131 result = result_type(result) 132 return (result, pos) 133 shift += 7 134 if shift >= 64: 135 raise _DecodeError('Too many bytes when decoding varint.') 136 return DecodeVarint 137 138 139def _SignedVarintDecoder(bits, result_type): 140 """Like _VarintDecoder() but decodes signed values.""" 141 142 signbit = 1 << (bits - 1) 143 mask = (1 << bits) - 1 144 145 def DecodeVarint(buffer, pos): 146 result = 0 147 shift = 0 148 while 1: 149 b = six.indexbytes(buffer, pos) 150 result |= ((b & 0x7f) << shift) 151 pos += 1 152 if not (b & 0x80): 153 result &= mask 154 result = (result ^ signbit) - signbit 155 result = result_type(result) 156 return (result, pos) 157 shift += 7 158 if shift >= 64: 159 raise _DecodeError('Too many bytes when decoding varint.') 160 return DecodeVarint 161 162# We force 32-bit values to int and 64-bit values to long to make 163# alternate implementations where the distinction is more significant 164# (e.g. the C++ implementation) simpler. 165 166_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) 167_DecodeSignedVarint = _SignedVarintDecoder(64, long) 168 169# Use these versions for values which must be limited to 32 bits. 170_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) 171_DecodeSignedVarint32 = _SignedVarintDecoder(32, int) 172 173 174def ReadTag(buffer, pos): 175 """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. 176 177 We return the raw bytes of the tag rather than decoding them. The raw 178 bytes can then be used to look up the proper decoder. This effectively allows 179 us to trade some work that would be done in pure-python (decoding a varint) 180 for work that is done in C (searching for a byte string in a hash table). 181 In a low-level language it would be much cheaper to decode the varint and 182 use that, but not in Python. 183 184 Args: 185 buffer: memoryview object of the encoded bytes 186 pos: int of the current position to start from 187 188 Returns: 189 Tuple[bytes, int] of the tag data and new position. 190 """ 191 start = pos 192 while six.indexbytes(buffer, pos) & 0x80: 193 pos += 1 194 pos += 1 195 196 tag_bytes = buffer[start:pos].tobytes() 197 return tag_bytes, pos 198 199 200# -------------------------------------------------------------------- 201 202 203def _SimpleDecoder(wire_type, decode_value): 204 """Return a constructor for a decoder for fields of a particular type. 205 206 Args: 207 wire_type: The field's wire type. 208 decode_value: A function which decodes an individual value, e.g. 209 _DecodeVarint() 210 """ 211 212 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default): 213 if is_packed: 214 local_DecodeVarint = _DecodeVarint 215 def DecodePackedField(buffer, pos, end, message, field_dict): 216 value = field_dict.get(key) 217 if value is None: 218 value = field_dict.setdefault(key, new_default(message)) 219 (endpoint, pos) = local_DecodeVarint(buffer, pos) 220 endpoint += pos 221 if endpoint > end: 222 raise _DecodeError('Truncated message.') 223 while pos < endpoint: 224 (element, pos) = decode_value(buffer, pos) 225 value.append(element) 226 if pos > endpoint: 227 del value[-1] # Discard corrupt value. 228 raise _DecodeError('Packed element was truncated.') 229 return pos 230 return DecodePackedField 231 elif is_repeated: 232 tag_bytes = encoder.TagBytes(field_number, wire_type) 233 tag_len = len(tag_bytes) 234 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 235 value = field_dict.get(key) 236 if value is None: 237 value = field_dict.setdefault(key, new_default(message)) 238 while 1: 239 (element, new_pos) = decode_value(buffer, pos) 240 value.append(element) 241 # Predict that the next tag is another copy of the same repeated 242 # field. 243 pos = new_pos + tag_len 244 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 245 # Prediction failed. Return. 246 if new_pos > end: 247 raise _DecodeError('Truncated message.') 248 return new_pos 249 return DecodeRepeatedField 250 else: 251 def DecodeField(buffer, pos, end, message, field_dict): 252 (field_dict[key], pos) = decode_value(buffer, pos) 253 if pos > end: 254 del field_dict[key] # Discard corrupt value. 255 raise _DecodeError('Truncated message.') 256 return pos 257 return DecodeField 258 259 return SpecificDecoder 260 261 262def _ModifiedDecoder(wire_type, decode_value, modify_value): 263 """Like SimpleDecoder but additionally invokes modify_value on every value 264 before storing it. Usually modify_value is ZigZagDecode. 265 """ 266 267 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 268 # not enough to make a significant difference. 269 270 def InnerDecode(buffer, pos): 271 (result, new_pos) = decode_value(buffer, pos) 272 return (modify_value(result), new_pos) 273 return _SimpleDecoder(wire_type, InnerDecode) 274 275 276def _StructPackDecoder(wire_type, format): 277 """Return a constructor for a decoder for a fixed-width field. 278 279 Args: 280 wire_type: The field's wire type. 281 format: The format string to pass to struct.unpack(). 282 """ 283 284 value_size = struct.calcsize(format) 285 local_unpack = struct.unpack 286 287 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 288 # not enough to make a significant difference. 289 290 # Note that we expect someone up-stack to catch struct.error and convert 291 # it to _DecodeError -- this way we don't have to set up exception- 292 # handling blocks every time we parse one value. 293 294 def InnerDecode(buffer, pos): 295 new_pos = pos + value_size 296 result = local_unpack(format, buffer[pos:new_pos])[0] 297 return (result, new_pos) 298 return _SimpleDecoder(wire_type, InnerDecode) 299 300 301def _FloatDecoder(): 302 """Returns a decoder for a float field. 303 304 This code works around a bug in struct.unpack for non-finite 32-bit 305 floating-point values. 306 """ 307 308 local_unpack = struct.unpack 309 310 def InnerDecode(buffer, pos): 311 """Decode serialized float to a float and new position. 312 313 Args: 314 buffer: memoryview of the serialized bytes 315 pos: int, position in the memory view to start at. 316 317 Returns: 318 Tuple[float, int] of the deserialized float value and new position 319 in the serialized data. 320 """ 321 # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign 322 # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. 323 new_pos = pos + 4 324 float_bytes = buffer[pos:new_pos].tobytes() 325 326 # If this value has all its exponent bits set, then it's non-finite. 327 # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. 328 # To avoid that, we parse it specially. 329 if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): 330 # If at least one significand bit is set... 331 if float_bytes[0:3] != b'\x00\x00\x80': 332 return (_NAN, new_pos) 333 # If sign bit is set... 334 if float_bytes[3:4] == b'\xFF': 335 return (_NEG_INF, new_pos) 336 return (_POS_INF, new_pos) 337 338 # Note that we expect someone up-stack to catch struct.error and convert 339 # it to _DecodeError -- this way we don't have to set up exception- 340 # handling blocks every time we parse one value. 341 result = local_unpack('<f', float_bytes)[0] 342 return (result, new_pos) 343 return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) 344 345 346def _DoubleDecoder(): 347 """Returns a decoder for a double field. 348 349 This code works around a bug in struct.unpack for not-a-number. 350 """ 351 352 local_unpack = struct.unpack 353 354 def InnerDecode(buffer, pos): 355 """Decode serialized double to a double and new position. 356 357 Args: 358 buffer: memoryview of the serialized bytes. 359 pos: int, position in the memory view to start at. 360 361 Returns: 362 Tuple[float, int] of the decoded double value and new position 363 in the serialized data. 364 """ 365 # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign 366 # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. 367 new_pos = pos + 8 368 double_bytes = buffer[pos:new_pos].tobytes() 369 370 # If this value has all its exponent bits set and at least one significand 371 # bit set, it's not a number. In Python 2.4, struct.unpack will treat it 372 # as inf or -inf. To avoid that, we treat it specially. 373 if ((double_bytes[7:8] in b'\x7F\xFF') 374 and (double_bytes[6:7] >= b'\xF0') 375 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): 376 return (_NAN, new_pos) 377 378 # Note that we expect someone up-stack to catch struct.error and convert 379 # it to _DecodeError -- this way we don't have to set up exception- 380 # handling blocks every time we parse one value. 381 result = local_unpack('<d', double_bytes)[0] 382 return (result, new_pos) 383 return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) 384 385 386def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): 387 enum_type = key.enum_type 388 if is_packed: 389 local_DecodeVarint = _DecodeVarint 390 def DecodePackedField(buffer, pos, end, message, field_dict): 391 """Decode serialized packed enum to its value and a new position. 392 393 Args: 394 buffer: memoryview of the serialized bytes. 395 pos: int, position in the memory view to start at. 396 end: int, end position of serialized data 397 message: Message object to store unknown fields in 398 field_dict: Map[Descriptor, Any] to store decoded values in. 399 400 Returns: 401 int, new position in serialized data. 402 """ 403 value = field_dict.get(key) 404 if value is None: 405 value = field_dict.setdefault(key, new_default(message)) 406 (endpoint, pos) = local_DecodeVarint(buffer, pos) 407 endpoint += pos 408 if endpoint > end: 409 raise _DecodeError('Truncated message.') 410 while pos < endpoint: 411 value_start_pos = pos 412 (element, pos) = _DecodeSignedVarint32(buffer, pos) 413 # pylint: disable=protected-access 414 if element in enum_type.values_by_number: 415 value.append(element) 416 else: 417 if not message._unknown_fields: 418 message._unknown_fields = [] 419 tag_bytes = encoder.TagBytes(field_number, 420 wire_format.WIRETYPE_VARINT) 421 422 message._unknown_fields.append( 423 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 424 # pylint: enable=protected-access 425 if pos > endpoint: 426 if element in enum_type.values_by_number: 427 del value[-1] # Discard corrupt value. 428 else: 429 del message._unknown_fields[-1] 430 raise _DecodeError('Packed element was truncated.') 431 return pos 432 return DecodePackedField 433 elif is_repeated: 434 tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) 435 tag_len = len(tag_bytes) 436 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 437 """Decode serialized repeated enum to its value and a new position. 438 439 Args: 440 buffer: memoryview of the serialized bytes. 441 pos: int, position in the memory view to start at. 442 end: int, end position of serialized data 443 message: Message object to store unknown fields in 444 field_dict: Map[Descriptor, Any] to store decoded values in. 445 446 Returns: 447 int, new position in serialized data. 448 """ 449 value = field_dict.get(key) 450 if value is None: 451 value = field_dict.setdefault(key, new_default(message)) 452 while 1: 453 (element, new_pos) = _DecodeSignedVarint32(buffer, pos) 454 # pylint: disable=protected-access 455 if element in enum_type.values_by_number: 456 value.append(element) 457 else: 458 if not message._unknown_fields: 459 message._unknown_fields = [] 460 message._unknown_fields.append( 461 (tag_bytes, buffer[pos:new_pos].tobytes())) 462 # pylint: enable=protected-access 463 # Predict that the next tag is another copy of the same repeated 464 # field. 465 pos = new_pos + tag_len 466 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 467 # Prediction failed. Return. 468 if new_pos > end: 469 raise _DecodeError('Truncated message.') 470 return new_pos 471 return DecodeRepeatedField 472 else: 473 def DecodeField(buffer, pos, end, message, field_dict): 474 """Decode serialized repeated enum to its value and a new position. 475 476 Args: 477 buffer: memoryview of the serialized bytes. 478 pos: int, position in the memory view to start at. 479 end: int, end position of serialized data 480 message: Message object to store unknown fields in 481 field_dict: Map[Descriptor, Any] to store decoded values in. 482 483 Returns: 484 int, new position in serialized data. 485 """ 486 value_start_pos = pos 487 (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) 488 if pos > end: 489 raise _DecodeError('Truncated message.') 490 # pylint: disable=protected-access 491 if enum_value in enum_type.values_by_number: 492 field_dict[key] = enum_value 493 else: 494 if not message._unknown_fields: 495 message._unknown_fields = [] 496 tag_bytes = encoder.TagBytes(field_number, 497 wire_format.WIRETYPE_VARINT) 498 message._unknown_fields.append( 499 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 500 # pylint: enable=protected-access 501 return pos 502 return DecodeField 503 504 505# -------------------------------------------------------------------- 506 507 508Int32Decoder = _SimpleDecoder( 509 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) 510 511Int64Decoder = _SimpleDecoder( 512 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) 513 514UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) 515UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) 516 517SInt32Decoder = _ModifiedDecoder( 518 wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) 519SInt64Decoder = _ModifiedDecoder( 520 wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) 521 522# Note that Python conveniently guarantees that when using the '<' prefix on 523# formats, they will also have the same size across all platforms (as opposed 524# to without the prefix, where their sizes depend on the C compiler's basic 525# type sizes). 526Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 527Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 528SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 529SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 530FloatDecoder = _FloatDecoder() 531DoubleDecoder = _DoubleDecoder() 532 533BoolDecoder = _ModifiedDecoder( 534 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 535 536 537def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 538 is_strict_utf8=False): 539 """Returns a decoder for a string field.""" 540 541 local_DecodeVarint = _DecodeVarint 542 local_unicode = six.text_type 543 544 def _ConvertToUnicode(memview): 545 """Convert byte to unicode.""" 546 byte_str = memview.tobytes() 547 try: 548 value = local_unicode(byte_str, 'utf-8') 549 except UnicodeDecodeError as e: 550 # add more information to the error message and re-raise it. 551 e.reason = '%s in field: %s' % (e, key.full_name) 552 raise 553 554 if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: 555 # Only do the check for python2 ucs4 when is_strict_utf8 enabled 556 if _SURROGATE_PATTERN.search(value): 557 reason = ('String field %s contains invalid UTF-8 data when parsing' 558 'a protocol buffer: surrogates not allowed. Use' 559 'the bytes type if you intend to send raw bytes.') % ( 560 key.full_name) 561 raise message.DecodeError(reason) 562 563 return value 564 565 assert not is_packed 566 if is_repeated: 567 tag_bytes = encoder.TagBytes(field_number, 568 wire_format.WIRETYPE_LENGTH_DELIMITED) 569 tag_len = len(tag_bytes) 570 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 571 value = field_dict.get(key) 572 if value is None: 573 value = field_dict.setdefault(key, new_default(message)) 574 while 1: 575 (size, pos) = local_DecodeVarint(buffer, pos) 576 new_pos = pos + size 577 if new_pos > end: 578 raise _DecodeError('Truncated string.') 579 value.append(_ConvertToUnicode(buffer[pos:new_pos])) 580 # Predict that the next tag is another copy of the same repeated field. 581 pos = new_pos + tag_len 582 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 583 # Prediction failed. Return. 584 return new_pos 585 return DecodeRepeatedField 586 else: 587 def DecodeField(buffer, pos, end, message, field_dict): 588 (size, pos) = local_DecodeVarint(buffer, pos) 589 new_pos = pos + size 590 if new_pos > end: 591 raise _DecodeError('Truncated string.') 592 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 593 return new_pos 594 return DecodeField 595 596 597def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): 598 """Returns a decoder for a bytes field.""" 599 600 local_DecodeVarint = _DecodeVarint 601 602 assert not is_packed 603 if is_repeated: 604 tag_bytes = encoder.TagBytes(field_number, 605 wire_format.WIRETYPE_LENGTH_DELIMITED) 606 tag_len = len(tag_bytes) 607 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 608 value = field_dict.get(key) 609 if value is None: 610 value = field_dict.setdefault(key, new_default(message)) 611 while 1: 612 (size, pos) = local_DecodeVarint(buffer, pos) 613 new_pos = pos + size 614 if new_pos > end: 615 raise _DecodeError('Truncated string.') 616 value.append(buffer[pos:new_pos].tobytes()) 617 # Predict that the next tag is another copy of the same repeated field. 618 pos = new_pos + tag_len 619 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 620 # Prediction failed. Return. 621 return new_pos 622 return DecodeRepeatedField 623 else: 624 def DecodeField(buffer, pos, end, message, field_dict): 625 (size, pos) = local_DecodeVarint(buffer, pos) 626 new_pos = pos + size 627 if new_pos > end: 628 raise _DecodeError('Truncated string.') 629 field_dict[key] = buffer[pos:new_pos].tobytes() 630 return new_pos 631 return DecodeField 632 633 634def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 635 """Returns a decoder for a group field.""" 636 637 end_tag_bytes = encoder.TagBytes(field_number, 638 wire_format.WIRETYPE_END_GROUP) 639 end_tag_len = len(end_tag_bytes) 640 641 assert not is_packed 642 if is_repeated: 643 tag_bytes = encoder.TagBytes(field_number, 644 wire_format.WIRETYPE_START_GROUP) 645 tag_len = len(tag_bytes) 646 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 647 value = field_dict.get(key) 648 if value is None: 649 value = field_dict.setdefault(key, new_default(message)) 650 while 1: 651 value = field_dict.get(key) 652 if value is None: 653 value = field_dict.setdefault(key, new_default(message)) 654 # Read sub-message. 655 pos = value.add()._InternalParse(buffer, pos, end) 656 # Read end tag. 657 new_pos = pos+end_tag_len 658 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 659 raise _DecodeError('Missing group end tag.') 660 # Predict that the next tag is another copy of the same repeated field. 661 pos = new_pos + tag_len 662 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 663 # Prediction failed. Return. 664 return new_pos 665 return DecodeRepeatedField 666 else: 667 def DecodeField(buffer, pos, end, message, field_dict): 668 value = field_dict.get(key) 669 if value is None: 670 value = field_dict.setdefault(key, new_default(message)) 671 # Read sub-message. 672 pos = value._InternalParse(buffer, pos, end) 673 # Read end tag. 674 new_pos = pos+end_tag_len 675 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 676 raise _DecodeError('Missing group end tag.') 677 return new_pos 678 return DecodeField 679 680 681def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): 682 """Returns a decoder for a message field.""" 683 684 local_DecodeVarint = _DecodeVarint 685 686 assert not is_packed 687 if is_repeated: 688 tag_bytes = encoder.TagBytes(field_number, 689 wire_format.WIRETYPE_LENGTH_DELIMITED) 690 tag_len = len(tag_bytes) 691 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 692 value = field_dict.get(key) 693 if value is None: 694 value = field_dict.setdefault(key, new_default(message)) 695 while 1: 696 # Read length. 697 (size, pos) = local_DecodeVarint(buffer, pos) 698 new_pos = pos + size 699 if new_pos > end: 700 raise _DecodeError('Truncated message.') 701 # Read sub-message. 702 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 703 # The only reason _InternalParse would return early is if it 704 # encountered an end-group tag. 705 raise _DecodeError('Unexpected end-group tag.') 706 # Predict that the next tag is another copy of the same repeated field. 707 pos = new_pos + tag_len 708 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 709 # Prediction failed. Return. 710 return new_pos 711 return DecodeRepeatedField 712 else: 713 def DecodeField(buffer, pos, end, message, field_dict): 714 value = field_dict.get(key) 715 if value is None: 716 value = field_dict.setdefault(key, new_default(message)) 717 # Read length. 718 (size, pos) = local_DecodeVarint(buffer, pos) 719 new_pos = pos + size 720 if new_pos > end: 721 raise _DecodeError('Truncated message.') 722 # Read sub-message. 723 if value._InternalParse(buffer, pos, new_pos) != new_pos: 724 # The only reason _InternalParse would return early is if it encountered 725 # an end-group tag. 726 raise _DecodeError('Unexpected end-group tag.') 727 return new_pos 728 return DecodeField 729 730 731# -------------------------------------------------------------------- 732 733MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) 734 735def MessageSetItemDecoder(descriptor): 736 """Returns a decoder for a MessageSet item. 737 738 The parameter is the message Descriptor. 739 740 The message set message looks like this: 741 message MessageSet { 742 repeated group Item = 1 { 743 required int32 type_id = 2; 744 required string message = 3; 745 } 746 } 747 """ 748 749 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 750 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 751 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 752 753 local_ReadTag = ReadTag 754 local_DecodeVarint = _DecodeVarint 755 local_SkipField = SkipField 756 757 def DecodeItem(buffer, pos, end, message, field_dict): 758 """Decode serialized message set to its value and new position. 759 760 Args: 761 buffer: memoryview of the serialized bytes. 762 pos: int, position in the memory view to start at. 763 end: int, end position of serialized data 764 message: Message object to store unknown fields in 765 field_dict: Map[Descriptor, Any] to store decoded values in. 766 767 Returns: 768 int, new position in serialized data. 769 """ 770 message_set_item_start = pos 771 type_id = -1 772 message_start = -1 773 message_end = -1 774 775 # Technically, type_id and message can appear in any order, so we need 776 # a little loop here. 777 while 1: 778 (tag_bytes, pos) = local_ReadTag(buffer, pos) 779 if tag_bytes == type_id_tag_bytes: 780 (type_id, pos) = local_DecodeVarint(buffer, pos) 781 elif tag_bytes == message_tag_bytes: 782 (size, message_start) = local_DecodeVarint(buffer, pos) 783 pos = message_end = message_start + size 784 elif tag_bytes == item_end_tag_bytes: 785 break 786 else: 787 pos = SkipField(buffer, pos, end, tag_bytes) 788 if pos == -1: 789 raise _DecodeError('Missing group end tag.') 790 791 if pos > end: 792 raise _DecodeError('Truncated message.') 793 794 if type_id == -1: 795 raise _DecodeError('MessageSet item missing type_id.') 796 if message_start == -1: 797 raise _DecodeError('MessageSet item missing message.') 798 799 extension = message.Extensions._FindExtensionByNumber(type_id) 800 # pylint: disable=protected-access 801 if extension is not None: 802 value = field_dict.get(extension) 803 if value is None: 804 value = field_dict.setdefault( 805 extension, extension.message_type._concrete_class()) 806 if value._InternalParse(buffer, message_start,message_end) != message_end: 807 # The only reason _InternalParse would return early is if it encountered 808 # an end-group tag. 809 raise _DecodeError('Unexpected end-group tag.') 810 else: 811 if not message._unknown_fields: 812 message._unknown_fields = [] 813 message._unknown_fields.append( 814 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) 815 # pylint: enable=protected-access 816 817 return pos 818 819 return DecodeItem 820 821# -------------------------------------------------------------------- 822 823def MapDecoder(field_descriptor, new_default, is_message_map): 824 """Returns a decoder for a map field.""" 825 826 key = field_descriptor 827 tag_bytes = encoder.TagBytes(field_descriptor.number, 828 wire_format.WIRETYPE_LENGTH_DELIMITED) 829 tag_len = len(tag_bytes) 830 local_DecodeVarint = _DecodeVarint 831 # Can't read _concrete_class yet; might not be initialized. 832 message_type = field_descriptor.message_type 833 834 def DecodeMap(buffer, pos, end, message, field_dict): 835 submsg = message_type._concrete_class() 836 value = field_dict.get(key) 837 if value is None: 838 value = field_dict.setdefault(key, new_default(message)) 839 while 1: 840 # Read length. 841 (size, pos) = local_DecodeVarint(buffer, pos) 842 new_pos = pos + size 843 if new_pos > end: 844 raise _DecodeError('Truncated message.') 845 # Read sub-message. 846 submsg.Clear() 847 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 848 # The only reason _InternalParse would return early is if it 849 # encountered an end-group tag. 850 raise _DecodeError('Unexpected end-group tag.') 851 852 if is_message_map: 853 value[submsg.key].MergeFrom(submsg.value) 854 else: 855 value[submsg.key] = submsg.value 856 857 # Predict that the next tag is another copy of the same repeated field. 858 pos = new_pos + tag_len 859 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 860 # Prediction failed. Return. 861 return new_pos 862 863 return DecodeMap 864 865# -------------------------------------------------------------------- 866# Optimization is not as heavy here because calls to SkipField() are rare, 867# except for handling end-group tags. 868 869def _SkipVarint(buffer, pos, end): 870 """Skip a varint value. Returns the new position.""" 871 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 872 # With this code, ord(b'') raises TypeError. Both are handled in 873 # python_message.py to generate a 'Truncated message' error. 874 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 875 pos += 1 876 pos += 1 877 if pos > end: 878 raise _DecodeError('Truncated message.') 879 return pos 880 881def _SkipFixed64(buffer, pos, end): 882 """Skip a fixed64 value. Returns the new position.""" 883 884 pos += 8 885 if pos > end: 886 raise _DecodeError('Truncated message.') 887 return pos 888 889 890def _DecodeFixed64(buffer, pos): 891 """Decode a fixed64.""" 892 new_pos = pos + 8 893 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 894 895 896def _SkipLengthDelimited(buffer, pos, end): 897 """Skip a length-delimited value. Returns the new position.""" 898 899 (size, pos) = _DecodeVarint(buffer, pos) 900 pos += size 901 if pos > end: 902 raise _DecodeError('Truncated message.') 903 return pos 904 905 906def _SkipGroup(buffer, pos, end): 907 """Skip sub-group. Returns the new position.""" 908 909 while 1: 910 (tag_bytes, pos) = ReadTag(buffer, pos) 911 new_pos = SkipField(buffer, pos, end, tag_bytes) 912 if new_pos == -1: 913 return pos 914 pos = new_pos 915 916 917def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): 918 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 919 920 unknown_field_set = containers.UnknownFieldSet() 921 while end_pos is None or pos < end_pos: 922 (tag_bytes, pos) = ReadTag(buffer, pos) 923 (tag, _) = _DecodeVarint(tag_bytes, 0) 924 field_number, wire_type = wire_format.UnpackTag(tag) 925 if wire_type == wire_format.WIRETYPE_END_GROUP: 926 break 927 (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) 928 # pylint: disable=protected-access 929 unknown_field_set._add(field_number, wire_type, data) 930 931 return (unknown_field_set, pos) 932 933 934def _DecodeUnknownField(buffer, pos, wire_type): 935 """Decode a unknown field. Returns the UnknownField and new position.""" 936 937 if wire_type == wire_format.WIRETYPE_VARINT: 938 (data, pos) = _DecodeVarint(buffer, pos) 939 elif wire_type == wire_format.WIRETYPE_FIXED64: 940 (data, pos) = _DecodeFixed64(buffer, pos) 941 elif wire_type == wire_format.WIRETYPE_FIXED32: 942 (data, pos) = _DecodeFixed32(buffer, pos) 943 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 944 (size, pos) = _DecodeVarint(buffer, pos) 945 data = buffer[pos:pos+size] 946 pos += size 947 elif wire_type == wire_format.WIRETYPE_START_GROUP: 948 (data, pos) = _DecodeUnknownFieldSet(buffer, pos) 949 elif wire_type == wire_format.WIRETYPE_END_GROUP: 950 return (0, -1) 951 else: 952 raise _DecodeError('Wrong wire type in tag.') 953 954 return (data, pos) 955 956 957def _EndGroup(buffer, pos, end): 958 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 959 960 return -1 961 962 963def _SkipFixed32(buffer, pos, end): 964 """Skip a fixed32 value. Returns the new position.""" 965 966 pos += 4 967 if pos > end: 968 raise _DecodeError('Truncated message.') 969 return pos 970 971 972def _DecodeFixed32(buffer, pos): 973 """Decode a fixed32.""" 974 975 new_pos = pos + 4 976 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 977 978 979def _RaiseInvalidWireType(buffer, pos, end): 980 """Skip function for unknown wire types. Raises an exception.""" 981 982 raise _DecodeError('Tag had invalid wire type.') 983 984def _FieldSkipper(): 985 """Constructs the SkipField function.""" 986 987 WIRETYPE_TO_SKIPPER = [ 988 _SkipVarint, 989 _SkipFixed64, 990 _SkipLengthDelimited, 991 _SkipGroup, 992 _EndGroup, 993 _SkipFixed32, 994 _RaiseInvalidWireType, 995 _RaiseInvalidWireType, 996 ] 997 998 wiretype_mask = wire_format.TAG_TYPE_MASK 999 1000 def SkipField(buffer, pos, end, tag_bytes): 1001 """Skips a field with the specified tag. 1002 1003 |pos| should point to the byte immediately after the tag. 1004 1005 Returns: 1006 The new position (after the tag value), or -1 if the tag is an end-group 1007 tag (in which case the calling loop should break). 1008 """ 1009 1010 # The wire type is always in the first byte since varints are little-endian. 1011 wire_type = ord(tag_bytes[0:1]) & wiretype_mask 1012 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 1013 1014 return SkipField 1015 1016SkipField = _FieldSkipper() 1017