# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Contains well known classes.

This file defines well known classes which need extra maintenance including:
  - Any
  - Duration
  - FieldMask
  - ListValue
  - Struct
  - Timestamp
"""

__author__ = 'jieluo@google.com (Jie Luo)'

import calendar
from datetime import datetime
from datetime import timedelta
import six

try:
  # Since python 3
  import collections.abc as collections_abc
except ImportError:
  # Python 2 fallback. The ABC aliases in `collections` were removed in
  # Python 3.10, so this branch must only be reached on Python 2.
  import collections as collections_abc

from google.protobuf.descriptor import FieldDescriptor

_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_NANOS_PER_SECOND = 1000000000
_NANOS_PER_MILLISECOND = 1000000
_NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
_DURATION_SECONDS_MAX = 315576000000


class Any(object):
  """Class for Any Message type."""

  __slots__ = ()

  def Pack(self, msg, type_url_prefix='type.googleapis.com/',
           deterministic=None):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack. Its descriptor's full name is appended to
        type_url_prefix to form type_url.
      type_url_prefix: Prefix of the resulting type_url. A '/' separator is
        inserted when the prefix is empty or does not already end with '/'.
      deterministic: Passed through to msg.SerializeToString().
    """
    if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
      self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    else:
      self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    self.value = msg.SerializeToString(deterministic=deterministic)

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Args:
      msg: The message to parse self.value into.

    Returns:
      True if the packed type matches msg's type and msg was populated,
      False (leaving msg untouched) otherwise.
    """
    descriptor = msg.DESCRIPTOR
    if not self.Is(descriptor):
      return False
    msg.ParseFromString(self.value)
    return True

  def TypeName(self):
    """Returns the protobuf type name of the inner message."""
    # Only last part is to be used: b/25630112
    return self.type_url.split('/')[-1]

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    return '/' in self.type_url and self.TypeName() == descriptor.full_name


# Naive datetime for the Unix epoch, interpreted as UTC. Equal to
# datetime.utcfromtimestamp(0), which is deprecated since Python 3.12.
_EPOCH_DATETIME = datetime(1970, 1, 1)


class Timestamp(object):
  """Class for Timestamp message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    # Normalize nanos into [0, 1e9) and carry any overflow/borrow into the
    # seconds so that out-of-range nanos still print correctly.
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = datetime(1970, 1, 1) + timedelta(days, seconds)

    result = dt.isoformat()
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
          accepted as long as they fit into nano-seconds precision.
          Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ValueError: On parsing problems.
    """
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      timezone_offset = value.rfind('-')
      # Dashes in the date part (before the 'T') separate year/month/day;
      # they are not the sign of a UTC offset.
      if timezone_offset < value.find('T'):
        timezone_offset = -1
    if timezone_offset == -1:
      raise ValueError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    # float() also accepts exponents and '_' digit separators, which are not
    # valid RFC 3339 fractional seconds, so require plain digits.
    if nano_value and not nano_value.isdigit():
      raise ValueError(
          'Failed to parse Timestamp: nanos {0} must contain only '
          'digits.'.format(nano_value))
    if nano_value:
      nanos = round(float('0.' + nano_value) * 1e9)
    else:
      nanos = 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ValueError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ValueError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      # A local time ahead of UTC (+hh:mm) is later than UTC by the offset,
      # so subtract it to get UTC seconds; add it for a '-' offset.
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    # NOTE: datetime.utcnow() is deprecated since Python 3.12; it is kept
    # here because this module still supports Python 2 (it imports six).
    self.FromDatetime(datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to a naive datetime in UTC.

    Sub-microsecond precision is dropped (nanos are rounded toward zero).
    """
    return _EPOCH_DATETIME + timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    Args:
      dt: A datetime. Naive values are treated as UTC; aware values are
        converted to UTC by utctimetuple().
    """
    # Using this guide: http://wiki.python.org/moin/WorkingWithTime
    # And this conversion guide: http://docs.python.org/library/time.html

    # Turn the date parameter into a tuple (struct_time) that can then be
    # manipulated into a long value of seconds.  During the conversion from
    # struct_time to long, the source date in UTC, and so it follows that the
    # correct transformation is calendar.timegm()
    self.seconds = calendar.timegm(dt.utctimetuple())
    self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND


class Duration(object):
  """Class for Duration message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts Duration to string format.

    Returns:
      A string converted from self. The string format will contains
      3, 6, or 9 fractional digits depending on the precision required to
      represent the exact Duration value. For example: "1s", "1.010s",
      "1.000000100s", "-3.100s"

    Raises:
      ValueError: If the Duration is out of range or has a sign mismatch.
    """
    _CheckDurationValid(self.seconds, self.nanos)
    if self.seconds < 0 or self.nanos < 0:
      # Print the sign once, then format the absolute value.
      result = '-'
      seconds = - self.seconds + int((0 - self.nanos) // 1e9)
      nanos = (0 - self.nanos) % 1e9
    else:
      result = ''
      seconds = self.seconds + int(self.nanos // 1e9)
      nanos = self.nanos % 1e9
    result += '%d' % seconds
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 's'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03ds' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06ds' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09ds' % nanos

  def FromJsonString(self, value):
    """Converts a string to Duration.

    Args:
      value: A string to be converted. The string must end with 's'. Any
          fractional digits (or none) are accepted as long as they fit into
          precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"

    Raises:
      ValueError: On parsing problems.
    """
    if len(value) < 1 or value[-1] != 's':
      raise ValueError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      pos = value.find('.')
      if pos == -1:
        seconds = int(value[:-1])
        nanos = 0
      else:
        seconds = int(value[:pos])
        # The fraction carries the sign of the whole value so that nanos and
        # seconds agree in sign (e.g. "-0.5s" -> 0s, -500000000ns).
        if value[0] == '-':
          nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
        else:
          nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
      _CheckDurationValid(seconds, nanos)
      self.seconds = seconds
      self.nanos = nanos
    except ValueError as e:
      raise ValueError(
          'Couldn\'t parse duration: {0} : {1}.'.format(value, e))

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts a Duration to microseconds (rounded toward zero)."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return self.seconds * _MICROS_PER_SECOND + micros

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds (rounded toward zero)."""
    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
    return self.seconds * _MILLIS_PER_SECOND + millis

  def ToSeconds(self):
    """Converts a Duration to seconds."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
                            nanos % _NANOS_PER_SECOND)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    self._NormalizeDuration(
        micros // _MICROS_PER_SECOND,
        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    self._NormalizeDuration(
        millis // _MILLIS_PER_SECOND,
        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self):
    """Converts Duration to timedelta (sub-microsecond part dropped)."""
    return timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
                            td.microseconds * _NANOS_PER_MICROSECOND)

  def _NormalizeDuration(self, seconds, nanos):
    """Set Duration by seconds and nanos.

    Args:
      seconds: Whole seconds, as produced by floor division.
      nanos: Non-negative remainder in nanoseconds, as produced by modulo.
    """
    # Force nanos to be negative if the duration is negative.
    if seconds < 0 and nanos > 0:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos


def _CheckDurationValid(seconds, nanos):
  """Validates the range and sign consistency of a Duration.

  Raises:
    ValueError: If seconds or nanos is out of range, or their signs differ.
  """
  if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
    raise ValueError(
        'Duration is not valid: Seconds {0} must be in range '
        '[-315576000000, 315576000000].'.format(seconds))
  if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
    raise ValueError(
        'Duration is not valid: Nanos {0} must be in range '
        '[-999999999, 999999999].'.format(nanos))
  if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
    raise ValueError(
        'Duration is not valid: Sign mismatch.')


def _RoundTowardZero(value, divider):
  """Truncates the remainder part after division."""
  # For some languages, the sign of the remainder is implementation
  # dependent if any of the operands is negative. Here we enforce
  # "rounded toward zero" semantics. For example, for (-5) / 2 an
  # implementation may give -3 as the result with the remainder being
  # 1. This function ensures we always return -2 (closer to zero).
  result = value // divider
  remainder = value % divider
  if result < 0 and remainder > 0:
    return result + 1
  else:
    return result


class FieldMask(object):
  """Class for FieldMask message type."""

  __slots__ = ()

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec.

    Raises:
      ValueError: If a path cannot be converted to camelCase.
    """
    camelcase_paths = []
    for path in self.paths:
      camelcase_paths.append(_SnakeCaseToCamelCase(path))
    return ','.join(camelcase_paths)

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec.

    Raises:
      ValueError: If a path already contains an underscore.
    """
    self.Clear()
    if value:
      for path in value.split(','):
        self.paths.append(_CamelCaseToSnakeCase(path))

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor."""
    for path in self.paths:
      if not _IsValidPath(message_descriptor, path):
        return False
    return True

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    for field in message_descriptor.fields:
      self.paths.append(field.name)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    tree = _FieldMaskTree(mask)
    tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    tree.MergeFromFieldMask(mask2)
    tree.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    intersection = _FieldMaskTree()
    for path in mask2.paths:
      tree.IntersectPath(path, intersection)
    intersection.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
          field if False.
      replace_repeated_field: Replace repeated field if True. Append
          elements of repeated field if False.

    Raises:
      ValueError: If a path names an unknown field, or has sub-paths under
        a field that is not a singular message.
    """
    tree = _FieldMaskTree(self)
    tree.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)


def _IsValidPath(message_descriptor, path):
  """Checks whether the path is valid for Message Descriptor.

  Every component except the last must name a singular message field; the
  last component may name any field of the innermost message.
  """
  parts = path.split('.')
  last = parts.pop()
  for name in parts:
    field = message_descriptor.fields_by_name.get(name)
    if (field is None or
        field.label == FieldDescriptor.LABEL_REPEATED or
        field.type != FieldDescriptor.TYPE_MESSAGE):
      return False
    message_descriptor = field.message_type
  return last in message_descriptor.fields_by_name


def _CheckFieldMaskMessage(message):
  """Raises ValueError if message is not a FieldMask."""
  message_descriptor = message.DESCRIPTOR
  if (message_descriptor.name != 'FieldMask' or
      message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
    raise ValueError('Message {0} is not a FieldMask.'.format(
        message_descriptor.full_name))


def _SnakeCaseToCamelCase(path_name):
  """Converts a path name from snake_case to camelCase.

  Raises:
    ValueError: If path_name contains uppercase letters, an underscore not
      followed by a lowercase letter, or a trailing underscore.
  """
  result = []
  after_underscore = False
  for c in path_name:
    if c.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if after_underscore:
      if c.islower():
        result.append(c.upper())
        after_underscore = False
      else:
        raise ValueError(
            'Fail to print FieldMask to Json string: The '
            'character after a "_" must be a lowercase letter '
            'in path name {0}.'.format(path_name))
    elif c == '_':
      after_underscore = True
    else:
      result.append(c)

  if after_underscore:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(result)


def _CamelCaseToSnakeCase(path_name):
  """Converts a field name from camelCase to snake_case.

  Raises:
    ValueError: If path_name contains an underscore.
  """
  result = []
  for c in path_name:
    if c == '_':
      raise ValueError('Fail to parse FieldMask: Path name '
                       '{0} must not contain "_"s.'.format(path_name))
    if c.isupper():
      result.append('_')
      result.append(c.lower())
    else:
      result.append(c)
  return ''.join(result)


class _FieldMaskTree(object):
  """Represents a FieldMask in a tree structure.

  For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  the FieldMaskTree will be:
      [_root] -+- foo -+- bar
            |       |
            |       +- baz
            |
            +- bar --- baz
  In the tree, each leaf node represents a field path.
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Initializes the tree by FieldMask."""
    # Nested dicts keyed by path component; an empty dict marks a leaf.
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Merges a FieldMask to the tree."""
    for path in field_mask.paths:
      self.AddPath(path)

  def AddPath(self, path):
    """Adds a field path into the tree.

    If the field path to add is a sub-path of an existing field path
    in the tree (i.e., a leaf node), it means the tree already matches
    the given path so nothing will be added to the tree. If the path
    matches an existing non-leaf node in the tree, that non-leaf node
    will be turned into a leaf node with all its children removed because
    the path matches all the node's children. Otherwise, a new path will
    be added.

    Args:
      path: The field path to add.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        node[name] = {}
      elif not node[name]:
        # Pre-existing empty node implies we already have this entire tree.
        return
      node = node[name]
    # Remove any sub-trees we might have had.
    node.clear()

  def ToFieldMask(self, field_mask):
    """Converts the tree to a FieldMask (paths sorted per level)."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Calculates the intersection part of a field path with this tree.

    Args:
      path: The field path to calculates.
      intersection: The out tree to record the intersection part.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        return
      elif not node[name]:
        intersection.AddPath(path)
        return
      node = node[name]
    intersection.AddLeafNodes(path, node)

  def AddLeafNodes(self, prefix, node):
    """Adds leaf nodes begin with prefix to this tree."""
    if not node:
      self.AddPath(prefix)
    for name in node:
      child_path = prefix + '.' + name
      self.AddLeafNodes(child_path, node[name])

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merge all fields specified by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)


def _StrConvert(value):
  """Converts value to str if it is not."""
  # This file is imported by c extension and some methods like ClearField
  # requires string for the field name. py2/py3 has different text
  # type and may use unicode.
  if not isinstance(value, str):
    return value.encode('utf-8')
  return value


def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merge all fields specified by a sub-tree from source to destination.

  Args:
    node: A _FieldMaskTree node (dict of path component -> child node).
    source: Message to read fields from.
    destination: Message to merge fields into.
    replace_message: Clear singular message fields in destination before
      merging them.
    replace_repeated: Clear repeated fields in destination before
      appending source elements.

  Raises:
    ValueError: If a node names a field that source does not have, or has
      children under a field that is not a singular message.
  """
  source_descriptor = source.DESCRIPTOR
  for name in node:
    child = node[name]
    # Use .get() so an unknown field reaches the descriptive ValueError
    # below instead of escaping as a bare KeyError.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        setattr(destination, name, getattr(source, name))


def _AddFieldPaths(node, prefix, field_mask):
  """Adds the field paths descended from node to field_mask."""
  if not node and prefix:
    field_mask.paths.append(prefix)
    return
  for name in sorted(node):
    if prefix:
      child_path = prefix + '.' + name
    else:
      child_path = name
    _AddFieldPaths(node[name], child_path, field_mask)


_INT_OR_FLOAT = six.integer_types + (float,)


def _StructValueToPython(value):
  """Recursively copies a Struct or ListValue into plain dicts and lists.

  Args:
    value: A Struct, a ListValue, or any other value, which is returned
      unchanged.

  Returns:
    A dict for a Struct or a list for a ListValue, with nested Struct and
    ListValue entries converted the same way.

  Raises:
    ValueError: If a nested Value has no kind set.
  """
  if isinstance(value, Struct):
    return {key: _StructValueToPython(item) for key, item in value.items()}
  if isinstance(value, ListValue):
    return [_StructValueToPython(item) for item in value.items()]
  return value


def _SetStructValue(struct_value, value):
  """Stores a Python value into a google.protobuf.Value message.

  Args:
    struct_value: The Value message to populate.
    value: None (null_value), bool (bool_value), string (string_value),
      int/float (number_value), dict or Struct (struct_value), or list,
      tuple or ListValue (list_value). Containers are converted
      recursively.

  Raises:
    ValueError: If value, or a nested element, has an unsupported type.
  """
  if isinstance(value, (Struct, ListValue)):
    # Snapshot message-backed containers into plain Python data first, so
    # that clearing the destination below cannot also wipe the source when
    # both refer to the same message (e.g. ``s['a'] = s['a']``).
    value = _StructValueToPython(value)
  if value is None:
    struct_value.null_value = 0
  elif isinstance(value, bool):
    # Note: this check must come before the number check because in Python
    # True and False are also considered numbers.
    struct_value.bool_value = value
  elif isinstance(value, six.string_types):
    struct_value.string_value = value
  elif isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
  elif isinstance(value, dict):
    struct_value.struct_value.Clear()
    struct_value.struct_value.update(value)
  elif isinstance(value, (list, tuple)):
    struct_value.list_value.Clear()
    struct_value.list_value.extend(value)
  else:
    raise ValueError('Unexpected type')


def _GetStructValue(struct_value):
  """Returns the Python view of a google.protobuf.Value message.

  Struct and ListValue kinds are returned as the (mutable) sub-messages
  themselves; null_value is returned as None.

  Raises:
    ValueError: If no kind is set on struct_value.
  """
  which = struct_value.WhichOneof('kind')
  if which == 'struct_value':
    return struct_value.struct_value
  elif which == 'null_value':
    return None
  elif which == 'number_value':
    return struct_value.number_value
  elif which == 'string_value':
    return struct_value.string_value
  elif which == 'bool_value':
    return struct_value.bool_value
  elif which == 'list_value':
    return struct_value.list_value
  elif which is None:
    raise ValueError('Value not set')


class Struct(object):
  """Class for Struct message type (a dict-like view over ``fields``)."""

  __slots__ = ()

  def __getitem__(self, key):
    """Returns the Python view of the Value stored under key."""
    return _GetStructValue(self.fields[key])

  def __contains__(self, item):
    return item in self.fields

  def __setitem__(self, key, value):
    """Stores value under key; see _SetStructValue for accepted types."""
    _SetStructValue(self.fields[key], value)

  def __delitem__(self, key):
    del self.fields[key]

  def __len__(self):
    return len(self.fields)

  def __iter__(self):
    return iter(self.fields)

  def keys(self):  # pylint: disable=invalid-name
    return self.fields.keys()

  def values(self):  # pylint: disable=invalid-name
    return [self[key] for key in self]

  def items(self):  # pylint: disable=invalid-name
    return [(key, self[key]) for key in self]

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('list_value'):
      # Clear will mark list_value modified which will indeed create a list.
      self.fields[key].list_value.Clear()
    return self.fields[key].list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    if not self.fields[key].HasField('struct_value'):
      # Clear will mark struct_value modified which will indeed create a struct.
      self.fields[key].struct_value.Clear()
    return self.fields[key].struct_value

  def update(self, dictionary):  # pylint: disable=invalid-name
    """Sets every key/value pair from dictionary (anything with items())."""
    for key, value in dictionary.items():
      _SetStructValue(self.fields[key], value)

collections_abc.MutableMapping.register(Struct)


class ListValue(object):
  """Class for ListValue message type (a list-like view over ``values``)."""

  __slots__ = ()

  def __len__(self):
    return len(self.values)

  def append(self, value):
    """Appends value; see _SetStructValue for accepted types."""
    _SetStructValue(self.values.add(), value)

  def extend(self, elem_seq):
    """Appends every element of elem_seq."""
    for value in elem_seq:
      self.append(value)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values.__getitem__(index))

  def __setitem__(self, index, value):
    _SetStructValue(self.values.__getitem__(index), value)

  def __delitem__(self, key):
    del self.values[key]

  def items(self):
    """Yields the Python view of each element in order."""
    for i in range(len(self)):
      yield self[i]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    struct_value = self.values.add().struct_value
    # Clear will mark struct_value modified which will indeed create a struct.
    struct_value.Clear()
    return struct_value

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    list_value = self.values.add().list_value
    # Clear will mark list_value modified which will indeed create a list.
    list_value.Clear()
    return list_value

collections_abc.MutableSequence.register(ListValue)


# Maps well-known message full names to the mixin classes defined above.
WKTBASES = {
    'google.protobuf.Any': Any,
    'google.protobuf.Duration': Duration,
    'google.protobuf.FieldMask': FieldMask,
    'google.protobuf.ListValue': ListValue,
    'google.protobuf.Struct': Struct,
    'google.protobuf.Timestamp': Timestamp,
}