# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The string containing the encoded message.
  pos:        The current position in the string.
  end:        The position in the string where the current message ends.  May be
              less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor but nothing in this
                 file should depend on that.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = 'kenton@google.com (Kenton Varda)'

import struct
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError


def _VarintDecoder(mask):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).  The
  decoder returns a (value, new_pos) pair.
  """

  local_ord = ord
  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = local_ord(buffer[pos])
      # Each byte contributes its low 7 bits, least-significant group first.
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # High bit clear: this was the last byte of the varint.
        result &= mask
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint


def _SignedVarintDecoder(mask):
  """Like _VarintDecoder() but decodes signed values."""

  local_ord = ord
  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = local_ord(buffer[pos])
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        if result > 0x7fffffffffffffff:
          # Reinterpret the unsigned 64-bit value as a negative number, then
          # set every bit above the mask.
          result -= (1 << 64)
          result |= ~mask
        else:
          result &= mask
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint


_DecodeVarint = _VarintDecoder((1 << 64) - 1)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)


def ReadTag(buffer, pos):
  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.

  We return the raw bytes of the tag rather than decoding them.  The raw
  bytes can then be used to look up the proper decoder.  This effectively allows
  us to trade some work that would be done in pure-python (decoding a varint)
  for work that is done in C (searching for a byte string in a hash table).
  In a low-level language it would be much cheaper to decode the varint and
  use that, but not in Python.
  """

  start = pos
  while ord(buffer[pos]) & 0x80:
    pos += 1
  pos += 1
  return (buffer[start:pos], pos)


# --------------------------------------------------------------------


def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # The packed payload is prefixed with its length in bytes.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        if pos > endpoint:
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (field_dict[key], pos) = decode_value(buffer, pos)
        if pos > end:
          del field_dict[key]  # Discard corrupt value.
          raise _DecodeError('Truncated message.')
        return pos
      return DecodeField

  return SpecificDecoder


def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like SimpleDecoder but additionally invokes modify_value on every value
  before storing it.  Usually modify_value is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    (result, new_pos) = decode_value(buffer, pos)
    return (modify_value(result), new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  value_size = struct.calcsize(format)
  local_unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    result = local_unpack(format, buffer[pos:new_pos])[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


# --------------------------------------------------------------------


Int32Decoder = EnumDecoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')

BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)


def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field."""

  local_DecodeVarint = _DecodeVarint
  local_unicode = unicode

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
      return new_pos
    return DecodeField


def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field."""

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(buffer[pos:new_pos])
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[pos:new_pos]
      return new_pos
    return DecodeField


def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field."""

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # NOTE(review): this repeats the lookup done just above the loop; it
        # looks redundant -- confirm before removing.
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField


def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field."""

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # NOTE(review): this repeats the lookup done just above the loop; it
        # looks redundant -- confirm before removing.
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField


# --------------------------------------------------------------------

MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    # -1 marks "not seen yet" for each of the three values below.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record where the payload lies; it is parsed after the loop,
        # once type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start,message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

    return pos

  return DecodeItem

# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.

def _SkipVarint(buffer, pos, end):
  """Skip a varint value.  Returns the new position."""

  while ord(buffer[pos]) & 0x80:
    pos += 1
  pos += 1
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipFixed64(buffer, pos, end):
  """Skip a fixed64 value.  Returns the new position."""

  pos += 8
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position."""

  (size, pos) = _DecodeVarint(buffer, pos)
  pos += size
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position."""

  while 1:
    (tag_bytes, pos) = ReadTag(buffer, pos)
    new_pos = SkipField(buffer, pos, end, tag_bytes)
    if new_pos == -1:
      # SkipField saw an end-group tag; pos is just past that tag.
      return pos
    pos = new_pos

def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""

  return -1

def _SkipFixed32(buffer, pos, end):
  """Skip a fixed32 value.  Returns the new position."""

  pos += 4
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception."""

  raise _DecodeError('Tag had invalid wire type.')

def _FieldSkipper():
  """Constructs the SkipField function."""

  # Indexed by the wire type extracted from the tag in SkipField below.
  WIRETYPE_TO_SKIPPER = [
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
      ]

  wiretype_mask = wire_format.TAG_TYPE_MASK
  local_ord = ord

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    wire_type = local_ord(tag_bytes[0]) & wiretype_mask
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)

  return SkipField

SkipField = _FieldSkipper()