# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Contains well known classes.

This file defines well known classes which need extra maintenance including:
  - Any
  - Duration
  - FieldMask
  - ListValue
  - Struct
  - Timestamp
"""

__author__ = 'jieluo@google.com (Jie Luo)'

import calendar
from datetime import datetime
from datetime import timedelta
import six

try:
  # Since python 3
  import collections.abc as collections_abc
except ImportError:
  # Fallback for Python 2, where collections.abc does not exist.
  import collections as collections_abc

from google.protobuf.descriptor import FieldDescriptor

_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_NANOS_PER_SECOND = 1000000000
_NANOS_PER_MILLISECOND = 1000000
_NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
_DURATION_SECONDS_MAX = 315576000000


class Any(object):
  """Class for Any Message type."""

  __slots__ = ()

  def Pack(self, msg, type_url_prefix='type.googleapis.com/',
           deterministic=None):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack. Its descriptor's full name becomes the last
        segment of ``type_url``.
      type_url_prefix: Prefix for ``type_url``. A '/' separator is inserted
        if the prefix does not already end with one.
      deterministic: Forwarded to ``msg.SerializeToString``.
    """
    if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
      self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    else:
      self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    self.value = msg.SerializeToString(deterministic=deterministic)

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Returns:
      True if the packed type matches ``msg``'s descriptor and ``msg`` was
      parsed from ``self.value``; False (leaving ``msg`` untouched) otherwise.
    """
    descriptor = msg.DESCRIPTOR
    if not self.Is(descriptor):
      return False
    msg.ParseFromString(self.value)
    return True

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only last part is to be used: b/25630112
    return self.type_url.split('/')[-1]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    return '/' in self.type_url and self.TypeName() == descriptor.full_name


# Naive datetime for the Unix epoch. Equivalent to the former
# datetime.utcfromtimestamp(0), which is deprecated since Python 3.12.
_EPOCH_DATETIME = datetime(1970, 1, 1)


class Timestamp(object):
  """Class for Timestamp message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    # Normalize so that 0 <= nanos < 1e9, carrying any overflow into seconds.
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = _EPOCH_DATETIME + timedelta(days, seconds)

    result = dt.isoformat()
    # Integer arithmetic keeps the fraction formatting exact.
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
          accepted as long as they fit into nano-seconds precision.
          Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ValueError: On parsing problems.
    """
    # Locate the timezone designator: 'Z', else '+', else the last '-'
    # (the last '-' is used because the date part also contains '-').
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      timezone_offset = value.rfind('-')
    if timezone_offset == -1:
      raise ValueError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    if 't' in second_value:
      raise ValueError(
          'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
          'lowercase \'t\' is not accepted'.format(second_value))
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    if nano_value:
      nanos = round(float('0.' + nano_value) * 1e9)
    else:
      nanos = 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ValueError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ValueError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      # Convert the local time to UTC: a '+hh:mm' offset is subtracted,
      # a '-hh:mm' offset is added.
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; kept
    # here because this module still supports Python 2 (it imports six).
    self.FromDatetime(datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to a naive datetime (sub-microsecond digits dropped)."""
    return _EPOCH_DATETIME + timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp."""
    # Using this guide: http://wiki.python.org/moin/WorkingWithTime
    # And this conversion guide: http://docs.python.org/library/time.html

    # Turn the date parameter into a tuple (struct_time) that can then be
    # manipulated into a long value of seconds.  During the conversion from
    # struct_time to long, the source date in UTC, and so it follows that the
    # correct transformation is calendar.timegm()
    self.seconds = calendar.timegm(dt.utctimetuple())
    self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND


class Duration(object):
  """Class for Duration message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Duration to string format.

    Returns:
      A string converted from self. The string format will contains
      3, 6, or 9 fractional digits depending on the precision required to
      represent the exact Duration value. For example: "1s", "1.010s",
      "1.000000100s", "-3.100s"

    Raises:
      ValueError: If seconds/nanos are out of range or have mismatched signs.
    """
    _CheckDurationValid(self.seconds, self.nanos)
    # After validation both fields share a sign, so emit the sign once and
    # format the magnitudes. Integer arithmetic keeps the result exact.
    if self.seconds < 0 or self.nanos < 0:
      result = '-'
      seconds = - self.seconds + (0 - self.nanos) // _NANOS_PER_SECOND
      nanos = (0 - self.nanos) % _NANOS_PER_SECOND
    else:
      result = ''
      seconds = self.seconds + self.nanos // _NANOS_PER_SECOND
      nanos = self.nanos % _NANOS_PER_SECOND
    result += '%d' % seconds
    if nanos == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 's'
    if nanos % _NANOS_PER_MILLISECOND == 0:
      # Serialize 3 fractional digits.
      return result + '.%03ds' % (nanos // _NANOS_PER_MILLISECOND)
    if nanos % _NANOS_PER_MICROSECOND == 0:
      # Serialize 6 fractional digits.
      return result + '.%06ds' % (nanos // _NANOS_PER_MICROSECOND)
    # Serialize 9 fractional digits.
    return result + '.%09ds' % nanos

  def FromJsonString(self, value):
    """Converts a string to Duration.

    Args:
      value: A string to be converted. The string must end with 's'. Any
          fractional digits (or none) are accepted as long as they fit into
          precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"

    Raises:
      ValueError: On parsing problems.
    """
    if len(value) < 1 or value[-1] != 's':
      raise ValueError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      pos = value.find('.')
      if pos == -1:
        seconds = int(value[:-1])
        nanos = 0
      else:
        seconds = int(value[:pos])
        # The fraction inherits the sign of the whole value so that e.g.
        # "-0.5s" yields nanos < 0 even though int("-0") == 0.
        if value[0] == '-':
          nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
        else:
          nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
      _CheckDurationValid(seconds, nanos)
      self.seconds = seconds
      self.nanos = nanos
    except ValueError as e:
      raise ValueError(
          'Couldn\'t parse duration: {0} : {1}.'.format(value, e))

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts a Duration to microseconds."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return self.seconds * _MICROS_PER_SECOND + micros

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds."""
    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
    return self.seconds * _MILLIS_PER_SECOND + millis

  def ToSeconds(self):
    """Converts a Duration to seconds."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
                            nanos % _NANOS_PER_SECOND)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    self._NormalizeDuration(
        micros // _MICROS_PER_SECOND,
        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    self._NormalizeDuration(
        millis // _MILLIS_PER_SECOND,
        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self):
    """Converts Duration to timedelta (sub-microsecond digits dropped)."""
    return timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
                            td.microseconds * _NANOS_PER_MICROSECOND)

  def _NormalizeDuration(self, seconds, nanos):
    """Set Duration by seconds and nanos.

    The From* callers use floor division/modulo, which yields a non-negative
    nanos remainder; this shifts it so nanos shares the sign of seconds.
    """
    # Force nanos to be negative if the duration is negative.
    if seconds < 0 and nanos > 0:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos


def _CheckDurationValid(seconds, nanos):
  """Raises ValueError if (seconds, nanos) is not a valid Duration.

  Checks the seconds range, the nanos range, and that the two fields do not
  have opposite signs.
  """
  if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
    raise ValueError(
        'Duration is not valid: Seconds {0} must be in range '
        '[-315576000000, 315576000000].'.format(seconds))
  if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
    raise ValueError(
        'Duration is not valid: Nanos {0} must be in range '
        '[-999999999, 999999999].'.format(nanos))
  if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
    raise ValueError(
        'Duration is not valid: Sign mismatch.')


def _RoundTowardZero(value, divider):
  """Truncates the remainder part after division."""
  # For some languages, the sign of the remainder is implementation
  # dependent if any of the operands is negative. Here we enforce
  # "rounded toward zero" semantics. For example, for (-5) / 2 an
  # implementation may give -3 as the result with the remainder being
  # 1. This function ensures we always return -2 (closer to zero).
  result = value // divider
  remainder = value % divider
  if result < 0 and remainder > 0:
    return result + 1
  else:
    return result


class FieldMask(object):
  """Class for FieldMask message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec."""
    camelcase_paths = []
    for path in self.paths:
      camelcase_paths.append(_SnakeCaseToCamelCase(path))
    return ','.join(camelcase_paths)

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec."""
    self.Clear()
    if value:
      for path in value.split(','):
        self.paths.append(_CamelCaseToSnakeCase(path))

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor."""
    for path in self.paths:
      if not _IsValidPath(message_descriptor, path):
        return False
    return True

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    for field in message_descriptor.fields:
      self.paths.append(field.name)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    tree = _FieldMaskTree(mask)
    tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    tree.MergeFromFieldMask(mask2)
    tree.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    intersection = _FieldMaskTree()
    for path in mask2.paths:
      tree.IntersectPath(path, intersection)
    intersection.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
          field if False.
      replace_repeated_field: Replace repeated field if True. Append
          elements of repeated field if False.
    """
    tree = _FieldMaskTree(self)
    tree.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)


def _IsValidPath(message_descriptor, path):
  """Checks whether the path is valid for Message Descriptor.

  Every component except the last must name a singular message field; the
  last component must name any field of the innermost message.
  """
  parts = path.split('.')
  last = parts.pop()
  for name in parts:
    field = message_descriptor.fields_by_name.get(name)
    if (field is None or
        field.label == FieldDescriptor.LABEL_REPEATED or
        field.type != FieldDescriptor.TYPE_MESSAGE):
      return False
    message_descriptor = field.message_type
  return last in message_descriptor.fields_by_name


def _CheckFieldMaskMessage(message):
  """Raises ValueError if message is not a FieldMask."""
  message_descriptor = message.DESCRIPTOR
  if (message_descriptor.name != 'FieldMask' or
      message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
    raise ValueError('Message {0} is not a FieldMask.'.format(
        message_descriptor.full_name))


def _SnakeCaseToCamelCase(path_name):
  """Converts a path name from snake_case to camelCase.

  Raises:
    ValueError: If the name contains uppercase letters, a '_' not followed
      by a lowercase letter, or a trailing '_'.
  """
  result = []
  after_underscore = False
  for c in path_name:
    if c.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if after_underscore:
      if c.islower():
        result.append(c.upper())
        after_underscore = False
      else:
        raise ValueError(
            'Fail to print FieldMask to Json string: The '
            'character after a "_" must be a lowercase letter '
            'in path name {0}.'.format(path_name))
    elif c == '_':
      after_underscore = True
    else:
      result.append(c)

  if after_underscore:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(result)


def _CamelCaseToSnakeCase(path_name):
  """Converts a field name from camelCase to snake_case.

  Raises:
    ValueError: If the name already contains a '_'.
  """
  result = []
  for c in path_name:
    if c == '_':
      raise ValueError('Fail to parse FieldMask: Path name '
                       '{0} must not contain "_"s.'.format(path_name))
    if c.isupper():
      result.append('_')
      result.append(c.lower())
    else:
      result.append(c)
  return ''.join(result)


class _FieldMaskTree(object):
  """Represents a FieldMask in a tree structure.

  For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  the FieldMaskTree will be:
      [_root] -+- foo -+- bar
            |       |
            |       +- baz
            |
            +- bar --- baz
  In the tree, each leaf node represents a field path.
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Initializes the tree by FieldMask."""
    # Nested dicts: each key is a field name, an empty dict marks a leaf.
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Merges a FieldMask to the tree."""
    for path in field_mask.paths:
      self.AddPath(path)

  def AddPath(self, path):
    """Adds a field path into the tree.

    If the field path to add is a sub-path of an existing field path
    in the tree (i.e., a leaf node), it means the tree already matches
    the given path so nothing will be added to the tree. If the path
    matches an existing non-leaf node in the tree, that non-leaf node
    will be turned into a leaf node with all its children removed because
    the path matches all the node's children. Otherwise, a new path will
    be added.

    Args:
      path: The field path to add.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        node[name] = {}
      elif not node[name]:
        # Pre-existing empty node implies we already have this entire tree.
        return
      node = node[name]
    # Remove any sub-trees we might have had.
    node.clear()

  def ToFieldMask(self, field_mask):
    """Converts the tree to a FieldMask (paths sorted, covered paths pruned)."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Calculates the intersection part of a field path with this tree.

    Args:
      path: The field path to calculates.
      intersection: The out tree to record the intersection part.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        return
      elif not node[name]:
        # A leaf on the way down covers the whole requested path.
        intersection.AddPath(path)
        return
      node = node[name]
    # The path ends at an inner node: every leaf below it is in common.
    intersection.AddLeafNodes(path, node)

  def AddLeafNodes(self, prefix, node):
    """Adds leaf nodes begin with prefix to this tree."""
    if not node:
      self.AddPath(prefix)
    for name in node:
      child_path = prefix + '.' + name
      self.AddLeafNodes(child_path, node[name])

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merge all fields specified by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)


def _StrConvert(value):
  """Converts value to str if it is not."""
  # This file is imported by c extension and some methods like ClearField
  # requires string for the field name. py2/py3 has different text
  # type and may use unicode.
  if not isinstance(value, str):
    return value.encode('utf-8')
  return value


def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merge all fields specified by a sub-tree from source to destination.

  Raises:
    ValueError: If a path names a field that does not exist in source's
      descriptor, or has sub-paths under a field that is not a singular
      message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name in node:
    child = node[name]
    # Use .get() so an unknown name reaches the ValueError below instead of
    # escaping as a bare KeyError (consistent with _IsValidPath).
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        setattr(destination, name, getattr(source, name))


def _AddFieldPaths(node, prefix, field_mask):
  """Adds the field paths descended from node to field_mask."""
  if not node and prefix:
    field_mask.paths.append(prefix)
    return
  # Sorted iteration makes the emitted paths deterministic (canonical form).
  for name in sorted(node):
    if prefix:
      child_path = prefix + '.' + name
    else:
      child_path = name
    _AddFieldPaths(node[name], child_path, field_mask)


_INT_OR_FLOAT = six.integer_types + (float,)


def _SetStructValue(struct_value, value):
  """Stores a Python value into a google.protobuf.Value message.

  Raises:
    ValueError: If value's type has no Value representation.
  """
  if value is None:
    struct_value.null_value = 0
  elif isinstance(value, bool):
    # Note: this check must come before the number check because in Python
    # True and False are also considered numbers.
    struct_value.bool_value = value
  elif isinstance(value, six.string_types):
    struct_value.string_value = value
  elif isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
  elif isinstance(value, (dict, Struct)):
    struct_value.struct_value.Clear()
    struct_value.struct_value.update(value)
  elif isinstance(value, (list, ListValue)):
    struct_value.list_value.Clear()
    struct_value.list_value.extend(value)
  else:
    raise ValueError('Unexpected type')


def _GetStructValue(struct_value):
  """Returns the Python value held by a google.protobuf.Value message.

  Raises:
    ValueError: If no member of the 'kind' oneof is set.
  """
  which = struct_value.WhichOneof('kind')
  if which == 'struct_value':
    return struct_value.struct_value
  elif which == 'null_value':
    return None
  elif which == 'number_value':
    return struct_value.number_value
  elif which == 'string_value':
    return struct_value.string_value
  elif which == 'bool_value':
    return struct_value.bool_value
  elif which == 'list_value':
    return struct_value.list_value
  elif which is None:
    raise ValueError('Value not set')


class Struct(object):
  """Class for Struct message type (a dict-like view over ``fields``)."""

  __slots__ = ()

  def __getitem__(self, key):
    return _GetStructValue(self.fields[key])

  def __contains__(self, item):
    return item in self.fields

  def __setitem__(self, key, value):
    _SetStructValue(self.fields[key], value)

  def __delitem__(self, key):
    del self.fields[key]

  def __len__(self):
    return len(self.fields)

  def __iter__(self):
    return iter(self.fields)

  def keys(self):  # pylint: disable=invalid-name
    return self.fields.keys()

  def values(self):  # pylint: disable=invalid-name
    return [self[key] for key in self]

  def items(self):  # pylint: disable=invalid-name
    return [(key, self[key]) for key in self]

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('list_value'):
      # Clear will mark list_value modified which will indeed create a list.
      self.fields[key].list_value.Clear()
    return self.fields[key].list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('struct_value'):
      # Clear will mark struct_value modified which will indeed create a struct.
      self.fields[key].struct_value.Clear()
    return self.fields[key].struct_value

  def update(self, dictionary):  # pylint: disable=invalid-name
    for key, value in dictionary.items():
      _SetStructValue(self.fields[key], value)

collections_abc.MutableMapping.register(Struct)


class ListValue(object):
  """Class for ListValue message type (a list-like view over ``values``)."""

  __slots__ = ()

  def __len__(self):
    return len(self.values)

  def append(self, value):
    _SetStructValue(self.values.add(), value)

  def extend(self, elem_seq):
    for value in elem_seq:
      self.append(value)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values[index])

  def __setitem__(self, index, value):
    _SetStructValue(self.values[index], value)

  def __delitem__(self, key):
    del self.values[key]

  def items(self):
    for i in range(len(self)):
      yield self[i]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    struct_value = self.values.add().struct_value
    # Clear will mark struct_value modified which will indeed create a struct.
    struct_value.Clear()
    return struct_value

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    list_value = self.values.add().list_value
    # Clear will mark list_value modified which will indeed create a list.
    list_value.Clear()
    return list_value

collections_abc.MutableSequence.register(ListValue)


# Maps well-known message full names to the mixin classes defined above.
WKTBASES = {
    'google.protobuf.Any': Any,
    'google.protobuf.Duration': Duration,
    'google.protobuf.FieldMask': FieldMask,
    'google.protobuf.ListValue': ListValue,
    'google.protobuf.Struct': Struct,
    'google.protobuf.Timestamp': Timestamp,
}