# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Contains container classes to represent different protocol buffer types.

This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:

- Repeated scalar fields - These are all repeated fields which aren't
  composite (e.g. they are of simple types like int32, string, etc).
- Repeated composite fields - Repeated fields which are composite. This
  includes groups and nested messages.
- Map fields - ScalarMap holds maps with scalar values, MessageMap holds maps
  with message values.
- Unknown fields - UnknownFieldSet and its helper classes hold parsed fields
  that are not described by the message's descriptor.
"""

import collections.abc
import copy
import functools
import pickle
from typing import (
    Any,
    Iterable,
    Iterator,
    List,
    MutableMapping,
    MutableSequence,
    NoReturn,
    Optional,
    Sequence,
    TypeVar,
    Union,
    overload,
)


_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')


class BaseContainer(Sequence[_T]):
  """Base container class."""

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The container will call this object's Modified() method when it is
        modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key."""
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  __hash__ = None

  def __repr__(self) -> str:
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the elements in place. Similar to list.sort().

    For backwards compatibility an old-style two-argument comparison function
    may be passed as ``sort_function=`` (or ``cmp=``). Python 3's list.sort()
    no longer accepts ``cmp``, so it is converted to a key function.

    Raises:
      TypeError: if a comparison function and ``key`` are both given.
    """
    # Continue to support the old sort_function keyword argument.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    if 'sort_function' in kwargs:
      kwargs['cmp'] = kwargs.pop('sort_function')
    if 'cmp' in kwargs:
      if 'key' in kwargs:
        raise TypeError('sort() got both a comparison function and a key')
      kwargs['key'] = functools.cmp_to_key(kwargs.pop('cmp'))
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the elements in place. Similar to list.reverse()."""
    self._values.reverse()


# TODO: Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
collections.abc.MutableSequence.register(BaseContainer)


class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend()."""
    # Check every element before mutating, so a bad element leaves the
    # container unchanged.
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq]
    if new_values:
      self._values.extend(new_values)
      self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item on the specified position.

    Raises:
      ValueError: for extended slices (slices with an explicit step).
    """
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
    else:
      self._values[key] = self._type_checker.CheckValue(value)
    self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")


# TODO: Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.
    """
    # Hoist attribute lookups out of the per-element loop.
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    values = self._values
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      values.append(new_element)
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one.

    Raises:
      TypeError: if other is not a container of the same class.
    """
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values


class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for scalar map values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    try:
      return self._values[key]
    except KeyError:
      # defaultdict-like: reading a missing key inserts the value checker's
      # default. Note that this path does not call Modified().
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class


class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for submessage values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor instance describing the message type
        of the map's values. Its _concrete_class is instantiated when a
        missing key is accessed.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      # Missing keys get a fresh submessage wired to our listener.
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map.  This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class


class _UnknownField:
  """A parsed unknown field."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    # Ordering is by field number only; used to sort fields for equality.
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    if self is other:
      return True
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)


class UnknownFieldRef:
  """A reference to the unknown field at a given index of an UnknownFieldSet.

  Every property access re-validates that the field still exists, raising
  ValueError once the parent set has been cleared or shrunk.
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    if not self._parent or self._index >= len(self._parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  @property
  def field_number(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._field_number

  @property
  def wire_type(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._wire_type

  @property
  def data(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._data


class UnknownFieldSet:
  """UnknownField container"""

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at index (negative allowed).

    Raises:
      ValueError: if the set has been cleared.
      IndexError: if index is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    if index < 0:
      index += size
    if index < 0 or index >= size:
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, index)

  def _internal_get(self, index):
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    # Recursively clear nested (group) sets, then mark this set as gone so
    # outstanding UnknownFieldRefs raise instead of reading stale data.
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    self._values = None