# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The string containing the encoded message.
  pos:        The current position in the string.
  end:        The position in the string where the current message ends.  May be
              less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor but nothing in this
                 file should depend on that.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = 'kenton@google.com (Kenton Varda)'

import struct

import six

if six.PY3:
  long = int

from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message


# This will overflow and thus become IEEE-754 "infinity".  We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
_NAN = _POS_INF * 0


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError


def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).  The
  decoder returns a (value, new_pos) pair.

  Args:
    mask: Bit mask and-ed with the accumulated value before conversion.
    result_type: Callable applied to the masked value to produce the result
      (int or long for the decoders built in this module).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError if the continuation bit is still set after ten bytes.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = six.indexbytes(buffer, pos)
      # The low 7 bits of each byte are payload, least-significant group first.
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # Continuation bit clear: this was the last byte.
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint


def _SignedVarintDecoder(mask, result_type):
  """Like _VarintDecoder() but decodes signed values.

  A raw value above 0x7fffffffffffffff is treated as a negative 64-bit
  two's-complement number: 2**64 is subtracted and the bits outside the mask
  are set.  Any other value is simply and-ed with the mask.

  Args:
    mask: Bit mask selecting the bits that carry the value.
    result_type: Callable applied to the adjusted value to produce the result.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError if the continuation bit is still set after ten bytes.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = six.indexbytes(buffer, pos)
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        if result > 0x7fffffffffffffff:
          # Reinterpret as negative and sign-extend past the mask.
          result -= (1 << 64)
          result |= ~mask
        else:
          result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint

# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)


def ReadTag(buffer, pos):
  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.

  We return the raw bytes of the tag rather than decoding them.  The raw
  bytes can then be used to look up the proper decoder.  This effectively allows
  us to trade some work that would be done in pure-python (decoding a varint)
  for work that is done in C (searching for a byte string in a hash table).
  In a low-level language it would be much cheaper to decode the varint and
  use that, but not in Python.

  Args:
    buffer: The encoded data.
    pos: Index of the first byte of the tag.

  Returns:
    (tag_bytes, new_pos): the slice of buffer holding the tag varint and the
    index of the byte immediately after it.
  """

  start = pos
  # Advance over every byte that has the continuation bit set...
  while six.indexbytes(buffer, pos) & 0x80:
    pos += 1
  # ...then over the final byte of the varint.
  pos += 1
  return (buffer[start:pos], pos)


# --------------------------------------------------------------------


def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()

  Returns:
    A decoder constructor with the MakeDecoder signature described in the
    module docstring.  It yields a packed, repeated or singular decoder
    depending on is_packed / is_repeated.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A packed field is a length prefix followed by back-to-back values.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        if pos > endpoint:
          # The last element ran past the declared length.
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (field_dict[key], pos) = decode_value(buffer, pos)
        if pos > end:
          del field_dict[key]  # Discard corrupt value.
          raise _DecodeError('Truncated message.')
        return pos
      return DecodeField

  return SpecificDecoder


def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like SimpleDecoder but additionally invokes modify_value on every value
  before storing it.  Usually modify_value is ZigZagDecode.

  Args:
    wire_type: The field's wire type.
    decode_value: A function (buffer, pos) -> (value, new_pos).
    modify_value: A function applied to each decoded value before it is stored.

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    (result, new_pos) = decode_value(buffer, pos)
    return (modify_value(result), new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """

  value_size = struct.calcsize(format)
  local_unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    result = local_unpack(format, buffer[pos:new_pos])[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.

  Returns:
    A decoder constructor for WIRETYPE_FIXED32, built via _SimpleDecoder().
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (_NAN, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (_NEG_INF, new_pos)
      return (_POS_INF, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)


def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.

  Returns:
    A decoder constructor for WIRETYPE_FIXED64, built via _SimpleDecoder().
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (_NAN, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)


def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  Values are read with _DecodeSignedVarint32.  A value whose number is present
  in key.enum_type.values_by_number is stored in the field; any other value is
  appended to message._unknown_fields as a (tag_bytes, raw_value_bytes) pair,
  where tag_bytes is the WIRETYPE_VARINT tag for this field number.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Is the field a packed field? (bool)
    key: The key used to look the field up in field_dict; its enum_type
      attribute is read here.
    new_default: A function which takes a message and returns a new container
      for the field (used by the packed and repeated decoders).

  Returns:
    A decoder function Decode(buffer, pos, end, message, field_dict).
  """
  enum_type = key.enum_type
  # The varint tag is the same for every value of this field, so build it once
  # here rather than each time an unknown value is recorded.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos]))
      if pos > endpoint:
        # The last element ran past the declared length; undo whichever
        # append it caused.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos]))
      return pos
    return DecodeField


# --------------------------------------------------------------------


Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)


def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field.

  Each value is a varint length followed by that many bytes, which are decoded
  as UTF-8.  A UnicodeDecodeError is re-raised with key.full_name added to its
  reason.  is_packed must be false.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(byte_str):
    try:
      return local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField


def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value is a varint length followed by that many bytes, stored as the
  corresponding slice of buffer.  is_packed must be false.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(buffer[pos:new_pos])
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[pos:new_pos]
      return new_pos
    return DecodeField


def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The group body is parsed with the sub-message's _InternalParse(), after
  which the END_GROUP tag for this field number must follow and must not
  extend past end; otherwise _DecodeError is raised.  is_packed must be false.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Look the container up once; it is the same object on every iteration.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField


def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value is a varint length followed by the sub-message bytes, which are
  parsed with the sub-message's _InternalParse() bounded to that length.
  is_packed must be false.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField


# --------------------------------------------------------------------

MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  If type_id maps to a known extension, the payload is parsed into that
  extension's message; otherwise the raw item bytes are appended to
  message._unknown_fields under MESSAGE_SET_ITEM_TAG.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the item's fields.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Remember where the payload lives; it is parsed after the loop.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
                                      buffer[message_set_item_start:pos]))

    return pos

  return DecodeItem

# --------------------------------------------------------------------

def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Args:
    field_descriptor: Descriptor of the map field; used as the field_dict key
      and read for its number and message_type.
    new_default: A function which takes a message and returns a new container
      for the field.
    is_message_map: If true, each entry's value is merged into value[key] with
      MergeFrom(); otherwise it is assigned.

  Returns:
    A decoder function DecodeMap(buffer, pos, end, message, field_dict).
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  local_DecodeVarint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # One scratch entry message is reused (and cleared) for every entry.
    submsg = message_type._concrete_class()
    value = field_dict.get(key)
    if value is None:
      value = field_dict.setdefault(key, new_default(message))
    while 1:
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      submsg.Clear()
      if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        value[submsg.key].MergeFrom(submsg.value)
      else:
        value[submsg.key] = submsg.value

      # Predict that the next tag is another copy of the same repeated field.
      pos = new_pos + tag_len
      if buffer[new_pos:pos] != tag_bytes or new_pos == end:
        # Prediction failed.  Return.
        return new_pos

  return DecodeMap

# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.

def _SkipVarint(buffer, pos, end):
  """Skip a varint value.  Returns the new position."""
  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
  # With this code, ord(b'') raises TypeError.  Both are handled in
  # python_message.py to generate a 'Truncated message' error.
  while ord(buffer[pos:pos+1]) & 0x80:
    pos += 1
  pos += 1
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipFixed64(buffer, pos, end):
  """Skip a fixed64 value.  Returns the new position."""

  pos += 8
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position."""

  (size, pos) = _DecodeVarint(buffer, pos)
  pos += size
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position."""

  while 1:
    (tag_bytes, pos) = ReadTag(buffer, pos)
    new_pos = SkipField(buffer, pos, end, tag_bytes)
    if new_pos == -1:
      # End-group tag reached; pos is already just past it.
      return pos
    pos = new_pos

def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""

  return -1

def _SkipFixed32(buffer, pos, end):
  """Skip a fixed32 value.  Returns the new position."""

  pos += 4
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception."""

  raise _DecodeError('Tag had invalid wire type.')

def _FieldSkipper():
  """Constructs the SkipField function."""

  # Indexed by the wire-type value extracted from the tag in SkipField().
  WIRETYPE_TO_SKIPPER = [
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
      ]

  wiretype_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    wire_type = ord(tag_bytes[0:1]) & wiretype_mask
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)

  return SkipField

SkipField = _FieldSkipper()