• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Contains well known classes.
32
This file defines well known classes which need extra maintenance including:
  - Any
  - Duration
  - FieldMask
  - ListValue
  - Struct
  - Timestamp
39"""
40
41__author__ = 'jieluo@google.com (Jie Luo)'
42
43import calendar
44from datetime import datetime
45from datetime import timedelta
46import six
47
48try:
49  # Since python 3
50  import collections.abc as collections_abc
51except ImportError:
52  # Won't work after python 3.8
53  import collections as collections_abc
54
55from google.protobuf.descriptor import FieldDescriptor
56
# strptime() format for the whole-seconds part of an RFC 3339 timestamp.
# (The misspelled name is kept because Timestamp.FromJsonString uses it.)
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
# Sub-second unit conversions, derived from one another.
_NANOS_PER_MICROSECOND = 1000
_MICROS_PER_SECOND = 1000000
_MILLIS_PER_SECOND = 1000
_NANOS_PER_MILLISECOND = _NANOS_PER_MICROSECOND * 1000
_NANOS_PER_SECOND = _NANOS_PER_MICROSECOND * _MICROS_PER_SECOND
_SECONDS_PER_DAY = 24 * 60 * 60
# Largest allowed |Duration.seconds|: 10,000 years of 365.25 days each.
_DURATION_SECONDS_MAX = 10000 * 36525 * _SECONDS_PER_DAY // 100
65
66
class Any(object):
  """Mixin adding pack/unpack helpers to the google.protobuf.Any message.

  The host message provides the ``type_url`` and ``value`` fields.
  """

  __slots__ = ()

  def Pack(self, msg, type_url_prefix='type.googleapis.com/',
           deterministic=None):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack.
      type_url_prefix: Prefix for ``type_url``; a '/' separator is inserted
          unless the prefix already ends with one.
      deterministic: Passed through to ``msg.SerializeToString()``.
    """
    # Slicing (rather than indexing) keeps an empty prefix from raising.
    separator = '' if type_url_prefix[-1:] == '/' else '/'
    self.type_url = '%s%s%s' % (
        type_url_prefix, separator, msg.DESCRIPTOR.full_name)
    self.value = msg.SerializeToString(deterministic=deterministic)

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Returns:
      True if the packed type matches msg's type and msg was parsed from the
      packed bytes; False (leaving msg untouched) otherwise.
    """
    if self.Is(msg.DESCRIPTOR):
      msg.ParseFromString(self.value)
      return True
    return False

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only the text after the last '/' is the type name: b/25630112
    return self.type_url.rsplit('/', 1)[-1]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    if '/' not in self.type_url:
      return False
    return self.TypeName() == descriptor.full_name
97
98
# Naive datetime for the Unix epoch; Timestamp.ToDatetime() adds offsets to
# it. Built directly instead of via datetime.utcfromtimestamp(0), which is
# deprecated as of Python 3.12 -- the resulting value is identical.
_EPOCH_DATETIME = datetime(1970, 1, 1)
100
101
class Timestamp(object):
  """Mixin adding conversions to the google.protobuf.Timestamp message.

  Methods read and write the host message's ``seconds`` and ``nanos``
  fields. ToDatetime() returns a naive datetime in UTC and FromDatetime()
  treats naive datetimes as UTC.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    # Normalize nanos into [0, 1e9), carrying any overflow into the seconds,
    # then split the seconds into whole days plus seconds-of-day.
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = datetime(1970, 1, 1) + timedelta(days, seconds)

    result = dt.isoformat()
    # nanos is already in [0, 1e9), so the precision choice is made with exact
    # integer arithmetic rather than float division.
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
          accepted as long as they fit into nano-seconds precision.
          Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ValueError: On parsing problems, including a fractional part with more
          than 9 digits or containing anything other than the digits 0-9.
    """
    # Locate the offset: 'Z', else '+', else the LAST '-' (earlier dashes
    # belong to the date).
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      timezone_offset = value.rfind('-')
    if timezone_offset == -1:
      raise ValueError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Split the date-time into its whole-seconds part and fractional digits.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    if len(nano_value) > 9:
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    # Validate explicitly: float() would also accept exponents, signs,
    # underscores and surrounding whitespace (e.g. '1e5' -> 10**13 nanos).
    if any(c not in '0123456789' for c in nano_value):
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} must contain only '
          'digits.'.format(nano_value))
    # Right-pad to nine digits so the fraction converts exactly to nanos.
    nanos = int(nano_value.ljust(9, '0'))
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    # Parse timezone offsets. The parsed time is local to that offset, so
    # the offset is subtracted ('+') or added ('-') to reach UTC.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ValueError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ValueError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    # NOTE(review): datetime.utcnow() is deprecated as of Python 3.12;
    # presumably kept for Python 2 compatibility (the module imports six).
    self.FromDatetime(datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    # Floor division keeps self.nanos non-negative for pre-epoch times.
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to a naive datetime in UTC.

    Sub-microsecond precision is truncated.
    """
    return _EPOCH_DATETIME + timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    Naive datetimes are interpreted as UTC; aware ones are converted to UTC
    through utctimetuple(). Precision is limited to dt's microseconds.
    """
    # Using this guide: http://wiki.python.org/moin/WorkingWithTime
    # And this conversion guide: http://docs.python.org/library/time.html

    # Turn the date parameter into a tuple (struct_time) that can then be
    # manipulated into a long value of seconds.  During the conversion from
    # struct_time to long, the source date in UTC, and so it follows that the
    # correct transformation is calendar.timegm()
    self.seconds = calendar.timegm(dt.utctimetuple())
    self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND
252
253
class Duration(object):
  """Mixin adding conversions to the google.protobuf.Duration message.

  Methods read and write the host message's ``seconds`` and ``nanos``
  fields. A negative duration carries the same sign on both fields.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Formats this Duration as its proto3 JSON string.

    Returns:
      The duration in seconds with an "s" suffix and 0, 3, 6 or 9 fractional
      digits -- the fewest that represent the value exactly. For example:
      "1s", "1.010s", "1.000000100s", "-3.100s".

    Raises:
      ValueError: If seconds/nanos are out of range or disagree in sign.
    """
    _CheckDurationValid(self.seconds, self.nanos)
    negative = self.seconds < 0 or self.nanos < 0
    # Validation guarantees |nanos| < 1e9 and that both fields share a sign,
    # so after pulling out the sign the magnitudes print directly.
    if negative:
      sign, whole, frac = '-', -self.seconds, -self.nanos
    else:
      sign, whole, frac = '', self.seconds, self.nanos
    text = '%s%d' % (sign, whole)
    if frac == 0:
      # No fractional digits: omit the '.' entirely.
      return text + 's'
    if frac % _NANOS_PER_MILLISECOND == 0:
      return text + '.%03ds' % (frac // _NANOS_PER_MILLISECOND)
    if frac % _NANOS_PER_MICROSECOND == 0:
      return text + '.%06ds' % (frac // _NANOS_PER_MICROSECOND)
    return text + '.%09ds' % frac

  def FromJsonString(self, value):
    """Parses a proto3 JSON duration string into this Duration.

    Args:
      value: A string ending in 's', with optional fractional seconds.
          For example: "1s", "1.01s", "1.0000001s", "-3.100s".

    Raises:
      ValueError: On parsing problems or an invalid resulting duration.
    """
    if len(value) < 1 or value[-1] != 's':
      raise ValueError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      point = value.find('.')
      if point == -1:
        seconds, nanos = int(value[:-1]), 0
      else:
        seconds = int(value[:point])
        # The fraction inherits the sign of the whole string.
        sign = '-' if value[0] == '-' else ''
        fraction = float('{0}0{1}'.format(sign, value[point:-1]))
        nanos = int(round(fraction * 1e9))
      _CheckDurationValid(seconds, nanos)
      self.seconds = seconds
      self.nanos = nanos
    except ValueError as e:
      raise ValueError(
          'Couldn\'t parse duration: {0} : {1}.'.format(value, e))

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.nanos + self.seconds * _NANOS_PER_SECOND

  def ToMicroseconds(self):
    """Converts a Duration to microseconds, truncating toward zero."""
    return (self.seconds * _MICROS_PER_SECOND +
            _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND))

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds, truncating toward zero."""
    return (self.seconds * _MILLIS_PER_SECOND +
            _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND))

  def ToSeconds(self):
    """Converts a Duration to seconds."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    whole, rest = divmod(nanos, _NANOS_PER_SECOND)
    self._NormalizeDuration(whole, rest)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    whole, rest = divmod(micros, _MICROS_PER_SECOND)
    self._NormalizeDuration(whole, rest * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    whole, rest = divmod(millis, _MILLIS_PER_SECOND)
    self._NormalizeDuration(whole, rest * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self):
    """Converts Duration to timedelta, truncating to microseconds."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return timedelta(seconds=self.seconds, microseconds=micros)

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    total_seconds = td.days * _SECONDS_PER_DAY + td.seconds
    self._NormalizeDuration(
        total_seconds, td.microseconds * _NANOS_PER_MICROSECOND)

  def _NormalizeDuration(self, seconds, nanos):
    """Stores seconds/nanos, making nanos negative for negative durations."""
    # Floor-based callers hand over a positive nanos remainder even when the
    # duration is negative; borrow one second to flip its sign.
    if seconds < 0 < nanos:
      seconds, nanos = seconds + 1, nanos - _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos
382
383
def _CheckDurationValid(seconds, nanos):
  """Validates a Duration's seconds/nanos pair.

  Raises:
    ValueError: If seconds is outside +/-_DURATION_SECONDS_MAX, if |nanos|
        is a full second or more, or if the two fields have opposite signs.
  """
  seconds_out_of_range = (seconds < -_DURATION_SECONDS_MAX or
                          seconds > _DURATION_SECONDS_MAX)
  if seconds_out_of_range:
    raise ValueError(
        'Duration is not valid: Seconds {0} must be in range '
        '[-315576000000, 315576000000].'.format(seconds))
  nanos_out_of_range = (nanos <= -_NANOS_PER_SECOND or
                        nanos >= _NANOS_PER_SECOND)
  if nanos_out_of_range:
    raise ValueError(
        'Duration is not valid: Nanos {0} must be in range '
        '[-999999999, 999999999].'.format(nanos))
  # Zero is compatible with either sign; only strict disagreement fails.
  signs_disagree = (seconds > 0 > nanos) or (seconds < 0 < nanos)
  if signs_disagree:
    raise ValueError(
        'Duration is not valid: Sign mismatch.')
396
397
def _RoundTowardZero(value, divider):
  """Divides value by divider, truncating the quotient toward zero.

  Python's floor division rounds toward negative infinity, e.g. -5 // 2 is
  -3. This helper returns -2 for that case instead, matching the
  truncating division other languages use.
  """
  quotient, remainder = divmod(value, divider)
  # A negative floored quotient with a leftover remainder was rounded down;
  # step it back up toward zero.
  if quotient < 0 and remainder > 0:
    quotient += 1
  return quotient
411
412
class FieldMask(object):
  """Mixin adding helpers to the google.protobuf.FieldMask message.

  The host message provides the repeated string field ``paths``.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec.

    Each snake_case path becomes camelCase and the paths are comma-joined.
    """
    return ','.join(_SnakeCaseToCamelCase(path) for path in self.paths)

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec.

    Existing paths are cleared first; an empty string yields an empty mask.
    """
    self.Clear()
    if not value:
      return
    for camel_path in value.split(','):
      self.paths.append(_CamelCaseToSnakeCase(camel_path))

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor."""
    return all(_IsValidPath(message_descriptor, path) for path in self.paths)

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    for field_descriptor in message_descriptor.fields:
      self.paths.append(field_descriptor.name)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    _FieldMaskTree(mask).ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If either argument is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    merged = _FieldMaskTree(mask1)
    merged.MergeFromFieldMask(mask2)
    merged.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask.

    Raises:
      ValueError: If either argument is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    first = _FieldMaskTree(mask1)
    common = _FieldMaskTree()
    for path in mask2.paths:
      first.IntersectPath(path, common)
    common.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
          field if False.
      replace_repeated_field: Replace repeated field if True. Append
          elements of repeated field if False.
    """
    _FieldMaskTree(self).MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
492
493
def _IsValidPath(message_descriptor, path):
  """Checks whether the path is valid for Message Descriptor.

  Every component except the last must name a singular message field; the
  last component only has to exist in the innermost message type.
  """
  names = path.split('.')
  descriptor = message_descriptor
  for name in names[:-1]:
    field = descriptor.fields_by_name.get(name)
    if field is None:
      return False
    if field.label == FieldDescriptor.LABEL_REPEATED:
      return False
    if field.type != FieldDescriptor.TYPE_MESSAGE:
      return False
    descriptor = field.message_type
  return names[-1] in descriptor.fields_by_name
506
507
def _CheckFieldMaskMessage(message):
  """Raises ValueError if message is not a FieldMask.

  The check matches both the message name and its defining .proto file.
  """
  descriptor = message.DESCRIPTOR
  is_field_mask = (descriptor.name == 'FieldMask' and
                   descriptor.file.name == 'google/protobuf/field_mask.proto')
  if not is_field_mask:
    raise ValueError('Message {0} is not a FieldMask.'.format(
        descriptor.full_name))
515
516
def _SnakeCaseToCamelCase(path_name):
  """Converts a path name from snake_case to camelCase.

  Raises:
    ValueError: If path_name contains an uppercase letter, a '_' that is not
        followed by a lowercase letter, or a trailing '_'.
  """
  chars = []
  capitalize_next = False
  for ch in path_name:
    if ch.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if capitalize_next:
      if not ch.islower():
        raise ValueError(
            'Fail to print FieldMask to Json string: The '
            'character after a "_" must be a lowercase letter '
            'in path name {0}.'.format(path_name))
      chars.append(ch.upper())
      capitalize_next = False
    elif ch == '_':
      # Drop the underscore and uppercase whatever follows it.
      capitalize_next = True
    else:
      chars.append(ch)

  if capitalize_next:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(chars)
544
545
def _CamelCaseToSnakeCase(path_name):
  """Converts a field name from camelCase to snake_case.

  Raises:
    ValueError: If path_name already contains a '_'.
  """
  pieces = []
  for ch in path_name:
    if ch == '_':
      raise ValueError('Fail to parse FieldMask: Path name '
                       '{0} must not contain "_"s.'.format(path_name))
    # Each uppercase letter becomes '_' plus its lowercase form.
    pieces.append('_' + ch.lower() if ch.isupper() else ch)
  return ''.join(pieces)
559
560
class _FieldMaskTree(object):
  """Represents a FieldMask in a tree structure.

  For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  the FieldMaskTree will be:
      [_root] -+- foo -+- bar
            |       |
            |       +- baz
            |
            +- bar --- baz
  In the tree, each leaf node represents a field path.

  Nodes are plain dicts mapping a field name to its child node; an empty
  dict is a leaf, meaning the whole field is selected.
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Initializes the tree, seeding it with field_mask's paths if given."""
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Merges a FieldMask to the tree."""
    for path in field_mask.paths:
      self.AddPath(path)

  def AddPath(self, path):
    """Adds a field path into the tree.

    If the field path to add is a sub-path of an existing field path
    in the tree (i.e., a leaf node), it means the tree already matches
    the given path so nothing will be added to the tree. If the path
    matches an existing non-leaf node in the tree, that non-leaf node
    will be turned into a leaf node with all its children removed because
    the path matches all the node's children. Otherwise, a new path will
    be added.

    Args:
      path: The field path to add.
    """
    current = self._root
    for name in path.split('.'):
      child = current.get(name)
      if child is None:
        child = current[name] = {}
      elif not child:
        # An existing leaf already covers this path and everything below it.
        return
      current = child
    # Make the final node a leaf by dropping any descendants it had.
    current.clear()

  def ToFieldMask(self, field_mask):
    """Converts the tree to a FieldMask (paths sorted, replacing old ones)."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Calculates the intersection part of a field path with this tree.

    Args:
      path: The field path to calculates.
      intersection: The out tree to record the intersection part.
    """
    current = self._root
    for name in path.split('.'):
      if name not in current:
        # Path leaves the tree: nothing in common.
        return
      child = current[name]
      if not child:
        # A leaf here covers the whole remaining path.
        intersection.AddPath(path)
        return
      current = child
    # Path ends on an inner node: every leaf beneath it is in common.
    intersection.AddLeafNodes(path, current)

  def AddLeafNodes(self, prefix, node):
    """Adds leaf nodes begin with prefix to this tree."""
    if not node:
      self.AddPath(prefix)
    for name in node:
      self.AddLeafNodes(prefix + '.' + name, node[name])

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merge all fields specified by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)
648
649
def _StrConvert(value):
  """Returns value as the native str type, UTF-8 encoding it if needed.

  Methods such as ClearField want a native str field name, and the text
  type differs between Python 2 and 3 (unicode vs str).
  """
  if isinstance(value, str):
    return value
  return value.encode('utf-8')
658
659
def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merge all fields specified by a sub-tree from source to destination.

  Args:
    node: A _FieldMaskTree node: a dict mapping field names to child nodes,
        where an empty dict means the whole field is selected.
    source: The message to copy values from.
    destination: The message to merge values into.
    replace_message: If True, a selected singular message field is cleared
        in destination before the source value is merged in.
    replace_repeated: If True, a selected repeated field is cleared in
        destination before the source elements are appended.

  Raises:
    ValueError: If node names a field that source's message type lacks, or
        gives sub-paths to a field that is not a singular message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name in node:
    child = node[name]
    # .get() lets an unknown name reach the descriptive ValueError below;
    # plain indexing raised a bare KeyError and left the None check dead.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      # Nothing to copy when the sub-message is unset in source.
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        setattr(destination, name, getattr(source, name))
696
697
def _AddFieldPaths(node, prefix, field_mask):
  """Appends every leaf path under node to field_mask.paths.

  Children are visited in sorted name order, so paths come out sorted.
  prefix is the dotted path leading to node ('' at the root).
  """
  if prefix and not node:
    # A leaf: the accumulated prefix is a complete field path.
    field_mask.paths.append(prefix)
    return
  for child_name in sorted(node):
    path = '.'.join((prefix, child_name)) if prefix else child_name
    _AddFieldPaths(node[child_name], path, field_mask)
709
710
# Python number types stored as a Value's number_value. bool is a subclass
# of int, so _SetStructValue tests for bool before consulting this tuple.
_INT_OR_FLOAT = six.integer_types + (float,)


def _SetStructValue(struct_value, value):
  """Stores a Python value into a google.protobuf.Value message.

  Args:
    struct_value: The Value message to fill in.
    value: None, a bool, a string, an int/float, a dict (stored as
        struct_value), or a list or tuple (stored as list_value).

  Raises:
    ValueError: If value has any other type.
  """
  if value is None:
    struct_value.null_value = 0
  elif isinstance(value, bool):
    # Note: this check must come before the number check because in Python
    # True and False are also considered numbers.
    struct_value.bool_value = value
  elif isinstance(value, six.string_types):
    struct_value.string_value = value
  elif isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
  elif isinstance(value, dict):
    struct_value.struct_value.Clear()
    struct_value.struct_value.update(value)
  elif isinstance(value, (list, tuple)):
    # Tuples are accepted like lists: both are ordered sequences that map
    # naturally onto a ListValue.
    struct_value.list_value.Clear()
    struct_value.list_value.extend(value)
  else:
    raise ValueError('Unexpected type')
733
734
def _GetStructValue(struct_value):
  """Returns the Python-side value held by a google.protobuf.Value message.

  Scalars come back as Python values and null as None. struct_value and
  list_value come back as the nested message objects themselves.

  Raises:
    ValueError: If no member of the 'kind' oneof is set.
  """
  which = struct_value.WhichOneof('kind')
  if which is None:
    raise ValueError('Value not set')
  if which == 'null_value':
    return None
  if which in ('struct_value', 'number_value', 'string_value', 'bool_value',
               'list_value'):
    return getattr(struct_value, which)
  # Any other oneof member name yields None.
  return None
751
752
class Struct(object):
  """Dict-like mixin for the google.protobuf.Struct message.

  Keys live in the message's ``fields`` map. Values are converted between
  google.protobuf.Value messages and Python values on access.
  """

  __slots__ = ()

  def __getitem__(self, key):
    entry = self.fields[key]
    return _GetStructValue(entry)

  def __contains__(self, item):
    return item in self.fields

  def __setitem__(self, key, value):
    entry = self.fields[key]
    _SetStructValue(entry, value)

  def __delitem__(self, key):
    del self.fields[key]

  def __len__(self):
    return len(self.fields)

  def __iter__(self):
    return iter(self.fields)

  def keys(self):  # pylint: disable=invalid-name
    return self.fields.keys()

  def values(self):  # pylint: disable=invalid-name
    return [self[name] for name in self]

  def items(self):  # pylint: disable=invalid-name
    return [(name, self[name]) for name in self]

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    if not entry.HasField('list_value'):
      # Clearing marks list_value as set, which materializes an empty list.
      entry.list_value.Clear()
    return entry.list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    entry = self.fields[key]
    if not entry.HasField('struct_value'):
      # Clearing marks struct_value as set, which materializes an empty struct.
      entry.struct_value.Clear()
    return entry.struct_value

  def update(self, dictionary):  # pylint: disable=invalid-name
    """Sets every key/value pair from dictionary into this Struct."""
    for key, value in dictionary.items():
      _SetStructValue(self.fields[key], value)

collections_abc.MutableMapping.register(Struct)
804
805
class ListValue(object):
  """List-like mixin for the google.protobuf.ListValue message.

  Elements live in the message's repeated ``values`` field and are converted
  to and from Python values on access.
  """

  __slots__ = ()

  def __len__(self):
    return len(self.values)

  def append(self, value):
    """Appends a Python value as a new google.protobuf.Value element."""
    new_entry = self.values.add()
    _SetStructValue(new_entry, value)

  def extend(self, elem_seq):
    """Appends each element of elem_seq, in order."""
    for element in elem_seq:
      self.append(element)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values[index])

  def __setitem__(self, index, value):
    _SetStructValue(self.values[index], value)

  def __delitem__(self, key):
    del self.values[key]

  def items(self):
    """Yields the Python value of every element, in order."""
    for position in range(len(self)):
      yield self[position]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    new_struct = self.values.add().struct_value
    # Clearing marks struct_value as set, which materializes an empty struct.
    new_struct.Clear()
    return new_struct

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    new_list = self.values.add().list_value
    # Clearing marks list_value as set, which materializes an empty list.
    new_list.Clear()
    return new_list

collections_abc.MutableSequence.register(ListValue)
850
851
# Maps each well-known type's full message name to the helper class defined
# above for it. Every helper class is named after its message, so the key is
# derived from the class name (insertion order: Any, Duration, FieldMask,
# ListValue, Struct, Timestamp).
WKTBASES = {
    'google.protobuf.' + _wkt_class.__name__: _wkt_class
    for _wkt_class in (Any, Duration, FieldMask, ListValue, Struct, Timestamp)
}
860