• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Code for decoding protocol buffer primitives.
9
10This code is very similar to encoder.py -- read the docs for that module first.
11
12A "decoder" is a function with the signature:
13  Decode(buffer, pos, end, message, field_dict)
14The arguments are:
15  buffer:     The string containing the encoded message.
16  pos:        The current position in the string.
17  end:        The position in the string where the current message ends.  May be
18              less than len(buffer) if we're reading a sub-message.
19  message:    The message object into which we're parsing.
20  field_dict: message._fields (avoids a hashtable lookup).
21The decoder reads the field and stores it into field_dict, returning the new
22buffer position.  A decoder for a repeated field may proactively decode all of
23the elements of that field, if they appear consecutively.
24
25Note that decoders may throw any of the following:
26  IndexError:  Indicates a truncated message.
27  struct.error:  Unpacking of a fixed-width field failed.
28  message.DecodeError:  Other errors.
29
30Decoders are expected to raise an exception if they are called with pos > end.
31This allows callers to be lax about bounds checking:  it's fine to read past
32"end" as long as you are sure that someone else will notice and throw an
33exception later on.
34
35Something up the call stack is expected to catch IndexError and struct.error
36and convert them to message.DecodeError.
37
38Decoders are constructed using decoder constructors with the signature:
39  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
40The arguments are:
41  field_number:  The field number of the field we want to decode.
42  is_repeated:   Is the field a repeated field? (bool)
43  is_packed:     Is the field a packed field? (bool)
44  key:           The key to use when looking up the field within field_dict.
45                 (This is actually the FieldDescriptor but nothing in this
46                 file should depend on that.)
47  new_default:   A function which takes a message object as a parameter and
48                 returns a new instance of the default value for this field.
49                 (This is called for repeated fields and sub-messages, when an
50                 instance does not already exist.)
51
52As with encoders, we define a decoder constructor for every type of field.
53Then, for every field of every message class we construct an actual decoder.
54That decoder goes into a dict indexed by tag, so when we decode a message
55we repeatedly read a tag, look up the corresponding decoder, and invoke it.
56"""
57
58__author__ = 'kenton@google.com (Kenton Varda)'
59
60import math
61import struct
62
63from google.protobuf import message
64from google.protobuf.internal import containers
65from google.protobuf.internal import encoder
66from google.protobuf.internal import wire_format
67
68
69# This is not for optimization, but rather to avoid conflicts with local
70# variables named "message".
71_DecodeError = message.DecodeError
72
73
74def _VarintDecoder(mask, result_type):
75  """Return an encoder for a basic varint value (does not include tag).
76
77  Decoded values will be bitwise-anded with the given mask before being
78  returned, e.g. to limit them to 32 bits.  The returned decoder does not
79  take the usual "end" parameter -- the caller is expected to do bounds checking
80  after the fact (often the caller can defer such checking until later).  The
81  decoder returns a (value, new_pos) pair.
82  """
83
84  def DecodeVarint(buffer, pos: int=None):
85    result = 0
86    shift = 0
87    while 1:
88      if pos is None:
89        # Read from BytesIO
90        try:
91          b = buffer.read(1)[0]
92        except IndexError as e:
93          if shift == 0:
94            # End of BytesIO.
95            return None
96          else:
97            raise ValueError('Fail to read varint %s' % str(e))
98      else:
99        b = buffer[pos]
100        pos += 1
101      result |= ((b & 0x7f) << shift)
102      if not (b & 0x80):
103        result &= mask
104        result = result_type(result)
105        return result if pos is None else (result, pos)
106      shift += 7
107      if shift >= 64:
108        raise _DecodeError('Too many bytes when decoding varint.')
109
110  return DecodeVarint
111
112
113def _SignedVarintDecoder(bits, result_type):
114  """Like _VarintDecoder() but decodes signed values."""
115
116  signbit = 1 << (bits - 1)
117  mask = (1 << bits) - 1
118
119  def DecodeVarint(buffer, pos):
120    result = 0
121    shift = 0
122    while 1:
123      b = buffer[pos]
124      result |= ((b & 0x7f) << shift)
125      pos += 1
126      if not (b & 0x80):
127        result &= mask
128        result = (result ^ signbit) - signbit
129        result = result_type(result)
130        return (result, pos)
131      shift += 7
132      if shift >= 64:
133        raise _DecodeError('Too many bytes when decoding varint.')
134  return DecodeVarint
135
# Both 32-bit and 64-bit values come back as plain Python ints.
_DecodeVarint = _VarintDecoder(mask=0xFFFFFFFFFFFFFFFF, result_type=int)
_DecodeSignedVarint = _SignedVarintDecoder(bits=64, result_type=int)

# Variants for values that must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(mask=0xFFFFFFFF, result_type=int)
_DecodeSignedVarint32 = _SignedVarintDecoder(bits=32, result_type=int)
143
144
def ReadTag(buffer, pos):
  """Read one tag from the memoryview without decoding it.

  The tag's raw bytes are returned instead of the decoded number so that the
  caller can use them directly as a dictionary key when looking up the
  matching decoder.  That trades pure-Python varint arithmetic for a hash
  lookup on a byte string done in C, which is the cheaper option in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  cursor = pos
  while True:
    byte = buffer[cursor]
    cursor += 1
    if not byte & 0x80:
      # A clear high bit marks the last byte of the tag.
      break

  return buffer[pos:cursor].tobytes(), cursor
169
170
171# --------------------------------------------------------------------
172
173
174def _SimpleDecoder(wire_type, decode_value):
175  """Return a constructor for a decoder for fields of a particular type.
176
177  Args:
178      wire_type:  The field's wire type.
179      decode_value:  A function which decodes an individual value, e.g.
180        _DecodeVarint()
181  """
182
183  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
184                      clear_if_default=False):
185    if is_packed:
186      local_DecodeVarint = _DecodeVarint
187      def DecodePackedField(buffer, pos, end, message, field_dict):
188        value = field_dict.get(key)
189        if value is None:
190          value = field_dict.setdefault(key, new_default(message))
191        (endpoint, pos) = local_DecodeVarint(buffer, pos)
192        endpoint += pos
193        if endpoint > end:
194          raise _DecodeError('Truncated message.')
195        while pos < endpoint:
196          (element, pos) = decode_value(buffer, pos)
197          value.append(element)
198        if pos > endpoint:
199          del value[-1]   # Discard corrupt value.
200          raise _DecodeError('Packed element was truncated.')
201        return pos
202      return DecodePackedField
203    elif is_repeated:
204      tag_bytes = encoder.TagBytes(field_number, wire_type)
205      tag_len = len(tag_bytes)
206      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
207        value = field_dict.get(key)
208        if value is None:
209          value = field_dict.setdefault(key, new_default(message))
210        while 1:
211          (element, new_pos) = decode_value(buffer, pos)
212          value.append(element)
213          # Predict that the next tag is another copy of the same repeated
214          # field.
215          pos = new_pos + tag_len
216          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
217            # Prediction failed.  Return.
218            if new_pos > end:
219              raise _DecodeError('Truncated message.')
220            return new_pos
221      return DecodeRepeatedField
222    else:
223      def DecodeField(buffer, pos, end, message, field_dict):
224        (new_value, pos) = decode_value(buffer, pos)
225        if pos > end:
226          raise _DecodeError('Truncated message.')
227        if clear_if_default and not new_value:
228          field_dict.pop(key, None)
229        else:
230          field_dict[key] = new_value
231        return pos
232      return DecodeField
233
234  return SpecificDecoder
235
236
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes each value through modify_value first.

  Every value produced by decode_value is transformed by modify_value before
  it is stored.  Usually modify_value is ZigZagDecode.
  """

  # Delegating to _SimpleDecoder costs a little speed compared with
  # duplicating its code, but not enough to matter.

  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
249
250
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # Delegating to _SimpleDecoder costs a little speed compared with
  # duplicating its code, but not enough to matter.

  # struct.error is deliberately left uncaught here: something further up
  # the stack is expected to turn it into _DecodeError, which saves setting
  # up an exception handler for every single value.

  def InnerDecode(buffer, pos):
    next_pos = pos + width
    fields = unpack(format, buffer[pos:next_pos])
    return (fields[0], next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
274
275
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are recognised by their byte pattern and mapped
  to math.nan / math.inf / -math.inf directly, working around a bug in
  struct.unpack for such values.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # The value is 32 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the
    # significand.
    next_pos = pos + 4
    raw = buffer[pos:next_pos].tobytes()
    top_byte = raw[3:4]

    # All exponent bits set means the value is non-finite.  In Python 2.4,
    # struct.unpack will convert it to a finite 64-bit value, so it is
    # handled by hand instead.
    if top_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80':
      if raw[0:3] != b'\x00\x00\x80':
        # At least one significand bit is set.
        return (math.nan, next_pos)
      if top_byte == b'\xFF':
        # Sign bit is set.
        return (-math.inf, next_pos)
      return (math.inf, next_pos)

    # struct.error is deliberately left uncaught here: something further up
    # the stack is expected to turn it into _DecodeError.
    return (unpack('<f', raw)[0], next_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
319
320
def _DoubleDecoder():
  """Returns a decoder for a double field.

  Not-a-number values are recognised by their byte pattern and mapped to
  math.nan directly, working around a bug in struct.unpack for NaN.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # The value is 64 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the
    # significand.
    next_pos = pos + 8
    raw = buffer[pos:next_pos].tobytes()

    # All exponent bits set plus at least one significand bit means NaN.  In
    # Python 2.4, struct.unpack will treat it as inf or -inf, so it is
    # handled by hand instead.
    top_byte_matches = raw[7:8] in b'\x7F\xFF'
    exponent_tail_set = raw[6:7] >= b'\xF0'
    significand_is_zero = raw[0:7] == b'\x00\x00\x00\x00\x00\x00\xF0'
    if top_byte_matches and exponent_tail_set and not significand_is_zero:
      return (math.nan, next_pos)

    # struct.error is deliberately left uncaught here: something further up
    # the stack is expected to turn it into _DecodeError.
    return (unpack('<d', raw)[0], next_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
359
360
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Numbers found in key.enum_type.values_by_number are stored as field
  values.  Any other number is kept, with its raw varint bytes, in
  message._unknown_fields under this field's varint tag.

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field is packed.
    key: Key used in field_dict; its enum_type attribute supplies the set of
      known numbers.
    new_default: Called with the message to create the repeated container
      when field_dict has no entry yet.
    clear_if_default: For a singular field, remove the entry instead of
      storing a falsy (zero) value.

  Returns:
    A decoder taking (buffer, pos, end, message, field_dict) and returning
    the new position.
  """
  enum_type = key.enum_type
  # Tag recorded alongside unrecognized numbers.  It depends only on the
  # field number, so it is built once here rather than per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # The packed payload starts with its own byte length.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element ran past the payload; undo whichever append it
        # caused before reporting the error.
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos
    return DecodeField
484
485
486# --------------------------------------------------------------------
487
488
Int32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint)

SInt32Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint32,
    modify_value=wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint,
    modify_value=wire_format.ZigZagDecode)

# With the '<' prefix, Python guarantees that struct formats have the same
# size on every platform.  Without the prefix the sizes would follow the C
# compiler's basic type sizes instead.
Fixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<I')
Fixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<Q')
SFixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<i')
SFixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

BoolDecoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint,
    modify_value=bool)
516
517
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is a varint length followed by that many bytes, which are
  decoded as UTF-8.  Packed encoding is not supported for this field type.
  """

  read_length = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Turn the bytes behind a memoryview into a str."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # Say which field held the bad bytes, then let the error propagate.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not length:
        # An empty string is recorded by removing the field entry.
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:value_end])
      return value_end
    return DecodeField

  expected_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while True:
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[pos:value_end]))
      # Guess that the next tag belongs to this same repeated field.
      pos = value_end + expected_tag_len
      if buffer[value_end:pos] != expected_tag or value_end == end:
        # Guess was wrong: stop right after the string just read.
        return value_end
  return DecodeRepeatedField
569
570
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is a varint length followed by that many bytes, which are
  copied out of the buffer as a bytes object.  Packed encoding is not
  supported for this field type.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not length:
        # An empty value is recorded by removing the field entry.
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:value_end].tobytes()
      return value_end
    return DecodeField

  expected_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:value_end].tobytes())
      # Guess that the next tag belongs to this same repeated field.
      pos = value_end + expected_tag_len
      if buffer[value_end:pos] != expected_tag or value_end == end:
        # Guess was wrong: stop right after the value just read.
        return value_end
  return DecodeRepeatedField
610
611
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group is parsed by handing the buffer to the sub-message's
  _InternalParse and then requiring the matching end-group tag right after
  the position it returns.  Packed encoding is not supported.

  The returned decoder takes (buffer, pos, end, message, field_dict) plus an
  optional current_depth (default 0) giving the nesting depth of the message
  being parsed.  It raises _DecodeError when nesting exceeds the module's
  _recursion_limit or when the end-group tag is missing.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      while 1:
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Read sub-message.
        current_depth += 1
        # _recursion_limit is a module-level setting that is not defined in
        # this part of the file.
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        # NOTE(review): assumes _InternalParse accepts the depth as a fourth
        # argument so nested decoders can keep counting -- confirm against
        # the message implementation.
        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
        current_depth -= 1
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      current_depth += 1
      # _recursion_limit is a module-level setting that is not defined in
      # this part of the file.
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      # NOTE(review): assumes _InternalParse accepts the depth as a fourth
      # argument so nested decoders can keep counting -- confirm against the
      # message implementation.
      pos = value._InternalParse(buffer, pos, end, current_depth)
      current_depth -= 1
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
667
668
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is a varint length followed by that many bytes, which are
  handed to the sub-message's _InternalParse.  Packed encoding is not
  supported.

  The returned decoder takes (buffer, pos, end, message, field_dict) plus an
  optional current_depth (default 0) giving the nesting depth of the message
  being parsed.  It raises _DecodeError when the sub-message is truncated,
  when nesting exceeds the module's _recursion_limit, or when parsing stops
  before the declared length is consumed.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        current_depth += 1
        # _recursion_limit is a module-level setting that is not defined in
        # this part of the file.
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        # NOTE(review): assumes _InternalParse accepts the depth as a fourth
        # argument so nested decoders can keep counting -- confirm against
        # the message implementation.
        if (
            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
            != new_pos
        ):
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        current_depth -= 1
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      current_depth += 1
      # _recursion_limit is a module-level setting that is not defined in
      # this part of the file.
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      # NOTE(review): assumes _InternalParse accepts the depth as a fourth
      # argument so nested decoders can keep counting -- confirm against the
      # message implementation.
      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
      return new_pos
    return DecodeField
727
728
729# --------------------------------------------------------------------
730
# Start-group tag of a MessageSet item (field number 1).
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      _DecodeError: the item is truncated, lacks its type_id or message
        field, or its payload stops at an unexpected end-group tag.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload lies; it is parsed once the
        # extension is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # Imported here because this module's top-level imports do not
          # provide message_factory.
          # pylint: disable=g-import-not-at-top
          from google.protobuf import message_factory
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start,message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # No extension registered for this type_id: keep the raw item bytes.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem
821
822
def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  The returned function takes the serialized bytes of a single item and
  extracts its type_id and the raw bytes of its message.
  """

  # Tags used inside a MessageSet item: field 2 is the type_id varint,
  # field 3 the length-delimited message, and field 1's end-group tag closes
  # the item.
  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Extract (type_id, message bytes) from a serialized MessageSet item.

    Args:
      buffer: the serialized item; slices of it must support `.tobytes()`.

    Returns:
      A tuple (type_id, message_bytes).

    Raises:
      _DecodeError: if a skipped field hits an end-group tag, the item runs
        past the end of `buffer`, or the type_id or message is missing.
    """
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet".  type_id has to be initialised here as well:
    # without it, an item that lacks a type_id tag would hit an
    # UnboundLocalError at the check below instead of a _DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    # type_id and message may appear in either order, so loop until the
    # item's end-group tag is read.
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
860
861# --------------------------------------------------------------------
862
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Args:
    field_descriptor: descriptor of the map field; also used as the key under
      which the map container is stored in field_dict.
    new_default: callable taking the parent message and returning a fresh
      map container.
    is_message_map: if true, entry values are merged in with CopyFrom();
      otherwise they are assigned directly.
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(
      field_descriptor.number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  local_DecodeVarint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    """Decode consecutive entries of the map field; return the new position."""
    # One scratch entry message is reused (cleared) for every entry parsed.
    submsg = message_type._concrete_class()
    value = field_dict.get(key)
    if value is None:
      value = field_dict.setdefault(key, new_default(message))
    while True:
      # Each entry is a length-prefixed sub-message.
      size, pos = local_DecodeVarint(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      submsg.Clear()
      if submsg._InternalParse(buffer, pos, entry_end) != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        value[submsg.key].CopyFrom(submsg.value)
      else:
        value[submsg.key] = submsg.value

      # Predict that the next tag is another entry of this same field.
      next_pos = entry_end + tag_len
      if buffer[entry_end:next_pos] != tag_bytes or entry_end == end:
        # Prediction failed.  Hand control back to the caller.
        return entry_end
      pos = next_pos

  return DecodeMap
904
905# --------------------------------------------------------------------
906# Optimization is not as heavy here because calls to SkipField() are rare,
907# except for handling end-group tags.
908
909def _SkipVarint(buffer, pos, end):
910  """Skip a varint value.  Returns the new position."""
911  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
912  # With this code, ord(b'') raises TypeError.  Both are handled in
913  # python_message.py to generate a 'Truncated message' error.
914  while ord(buffer[pos:pos+1].tobytes()) & 0x80:
915    pos += 1
916  pos += 1
917  if pos > end:
918    raise _DecodeError('Truncated message.')
919  return pos
920
921def _SkipFixed64(buffer, pos, end):
922  """Skip a fixed64 value.  Returns the new position."""
923
924  pos += 8
925  if pos > end:
926    raise _DecodeError('Truncated message.')
927  return pos
928
929
930def _DecodeFixed64(buffer, pos):
931  """Decode a fixed64."""
932  new_pos = pos + 8
933  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
934
935
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  Raises:
    _DecodeError: if the value would extend past `end`.
  """

  # The value is a varint length followed by that many bytes of data.
  size, data_start = _DecodeVarint(buffer, pos)
  data_end = data_start + size
  if data_end <= end:
    return data_end
  raise _DecodeError('Truncated message.')
944
945
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are skipped one after another until SkipField reports an end-group
  tag by returning -1; the position just after that tag is returned.
  """

  while True:
    tag_bytes, pos = ReadTag(buffer, pos)
    skipped_to = SkipField(buffer, pos, end, tag_bytes)
    if skipped_to == -1:
      # End-group tag: `pos` already points past it.
      return pos
    pos = skipped_to
955
956
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position.

  Args:
    buffer: the serialized bytes.
    pos: int, position to start decoding at.
    end_pos: optional int; decoding stops once `pos` reaches it.  When None,
      decoding continues until an end-group tag is read.
    current_depth: int, number of enclosing groups already being decoded.
      It is forwarded to _DecodeUnknownField so that nested groups can be
      checked against the recursion limit.

  Returns:
    A tuple (unknown_field_set, new_pos).
  """

  unknown_field_set = containers.UnknownFieldSet()
  while end_pos is None or pos < end_pos:
    (tag_bytes, pos) = ReadTag(buffer, pos)
    (tag, _) = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
    # pylint: disable=protected-access
    unknown_field_set._add(field_number, wire_type, data)

  return (unknown_field_set, pos)


def _DecodeUnknownField(buffer, pos, wire_type, current_depth=0):
  """Decode a unknown field.  Returns the UnknownField and new position.

  Args:
    buffer: the serialized bytes.
    pos: int, position just after the field's tag.
    wire_type: int, wire type taken from the field's tag.
    current_depth: int, number of enclosing groups already being decoded.
      Without this parameter the depth counter was an unbound local and every
      START_GROUP field failed with UnboundLocalError.

  Returns:
    A tuple (data, new_pos).  For an END_GROUP wire type this is (0, -1).

  Raises:
    _DecodeError: if the wire type is not recognised, or groups are nested
      to the recursion limit or deeper.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    data = buffer[pos:pos+size].tobytes()
    pos += size
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    # Entering a nested group: bound the nesting so that deeply nested input
    # cannot exhaust the interpreter stack.
    current_depth += 1
    if current_depth >= _recursion_limit:
      raise _DecodeError('Error parsing message: too many levels of nesting.')
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)
999
1000
1001def _EndGroup(buffer, pos, end):
1002  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
1003
1004  return -1
1005
1006
1007def _SkipFixed32(buffer, pos, end):
1008  """Skip a fixed32 value.  Returns the new position."""
1009
1010  pos += 4
1011  if pos > end:
1012    raise _DecodeError('Truncated message.')
1013  return pos
1014
1015
1016def _DecodeFixed32(buffer, pos):
1017  """Decode a fixed32."""
1018
1019  new_pos = pos + 4
1020  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
1021
# Nesting limit in effect until SetRecursionLimit() is called.
DEFAULT_RECURSION_LIMIT = 100
# Limit currently in effect; decoders compare their nesting depth against it.
_recursion_limit = DEFAULT_RECURSION_LIMIT


def SetRecursionLimit(new_limit):
  """Replace the module-wide recursion limit with `new_limit`."""
  global _recursion_limit
  _recursion_limit = new_limit
1029
1030
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception.

  The arguments are accepted only so the signature matches the other
  skippers.

  Raises:
    _DecodeError: always.
  """
  raise _DecodeError('Tag had invalid wire type.')
1035
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Skip functions indexed by wire type.  The last two entries reject wire
  # type values that have no skipper of their own.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  wiretype_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_byte = ord(tag_bytes[0:1])
    skipper = skippers[first_byte & wiretype_mask]
    return skipper(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()
1069