• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Contains container classes to represent different protocol buffer types.
9
10This file defines container classes which represent categories of protocol
11buffer field types which need extra maintenance. Currently these categories
12are:
13
14-   Repeated scalar fields - These are all repeated fields which aren't
15    composite (e.g. they are of simple types like int32, string, etc).
16-   Repeated composite fields - Repeated fields which are composite. This
17    includes groups and nested messages.
18"""
19
20import collections.abc
21import copy
22import pickle
23from typing import (
24    Any,
25    Iterable,
26    Iterator,
27    List,
28    MutableMapping,
29    MutableSequence,
30    NoReturn,
31    Optional,
32    Sequence,
33    TypeVar,
34    Union,
35    overload,
36)
37
38
# Generic type parameters used by the containers in this module: _T is the
# element type of the sequence containers (and the type of the `default`
# argument in the map `get()` overloads), _K is the key type of the map
# containers and _V is their value type.
_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')
42
43
class BaseContainer(Sequence[_T]):
  """Base container class.

  Stores its elements in a plain list (`_values`) and keeps a reference to the
  message listener. It provides the read-only sequence protocol plus in-place
  sort() and reverse(); the mutating list operations are defined by the
  concrete subclasses, which must also define __eq__.
  """

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        Subclasses call this object's Modified() method when the
        container is modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key.

    An int key returns a single element; a slice returns a new list holding
    the selected elements.
    """
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  # Containers are mutable and compare by value, so they are unhashable.
  __hash__ = None

  def __repr__(self) -> str:
    """Returns the repr() of the underlying list of values."""
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the elements in place. Similar to list.sort().

    Arguments are forwarded to list.sort(). In addition, the legacy
    `sort_function` keyword is accepted: a two-argument comparison function
    returning a negative, zero or positive number. When `key` is passed
    together with `sort_function`, the comparison is applied to the keys.

    The message listener is not notified by this method.
    """
    # Continue to support the old sort_function keyword argument.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    if 'sort_function' in kwargs:
      # list.sort() has no `cmp` parameter on Python 3, so the comparison
      # function has to be adapted into a key function instead.
      import functools  # Local import: only needed on this rare path.

      compare_key = functools.cmp_to_key(kwargs.pop('sort_function'))
      user_key = kwargs.pop('key', None)
      if user_key is None:
        kwargs['key'] = compare_key
      else:
        kwargs['key'] = lambda value: compare_key(user_key(value))
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the elements in place. Similar to list.reverse().

    The message listener is not notified by this method.
    """
    self._values.reverse()
96
97
# TODO: Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
# Until it is removed, this registration makes BaseContainer a virtual
# subclass of collections.abc.MutableSequence, so isinstance()/issubclass()
# checks succeed even though BaseContainer itself defines no mutating
# sequence methods besides sort() and reverse().
collections.abc.MutableSequence.register(BaseContainer)
101
102
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """List-like container for repeated scalar values with type checking.

  Values stored through append(), insert(), extend() and item assignment are
  first passed through the type checker's CheckValue(); MergeFrom() copies
  values without checking them. The message listener is told about
  modifications through its Modified() method.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """Creates an empty container.

    Args:
      message_listener: A MessageListener implementation. Its Modified()
        method is called when this container is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Type-checks `value` and adds it at the end, like list.append()."""
    checked = self._type_checker.CheckValue(value)
    self._values.append(checked)
    listener = self._message_listener
    # Skip the notification when the listener is already marked dirty.
    if not listener.dirty:
      listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Type-checks `value` and places it at index `key`, like list.insert()."""
    checked = self._type_checker.CheckValue(value)
    self._values.insert(key, checked)
    listener = self._message_listener
    # Skip the notification when the listener is already marked dirty.
    if not listener.dirty:
      listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Type-checks and appends every element of `elem_seq`, like list.extend().

    All elements are checked before any of them is stored. The listener is
    notified even when `elem_seq` turns out to be empty.
    """
    checker = self._type_checker
    checked_items = [checker.CheckValue(item) for item in elem_seq]
    if checked_items:
      self._values.extend(checked_items)
    self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type.

    The individual values are copied as they are; no type checking is done.
    """
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T):
    """Deletes the first occurrence of `elem`, like list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes the item at index `key` (default: last) and returns it."""
    popped = self._values[key]
    # Deleting through __delitem__ also takes care of notifying the listener.
    del self[key]
    return popped

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Stores type-checked value(s) at an index or a simple slice.

    Raises:
      ValueError: If `key` is a slice with a step.
    """
    if not isinstance(key, slice):
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()
      return
    if key.step is not None:
      raise ValueError('Extended slices not supported')
    self._values[key] = map(self._type_checker.CheckValue, value)
    self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Removes the element(s) selected by an index or slice."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the stored values with another container or sequence."""
    if other is self:
      return True
    own_values = self._values
    # Same-type comparison is the common case: compare the lists directly.
    # Anything else is presumably some other sequence type.
    if isinstance(other, self.__class__):
      candidate = other._values
    else:
      candidate = other
    return candidate == own_values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    """Returns a copy with a deep-copied listener and the same type checker."""
    listener_copy = copy.deepcopy(self._message_listener)
    duplicate = RepeatedScalarFieldContainer(listener_copy, self._type_checker)
    duplicate.MergeFrom(self)
    return duplicate

  def __reduce__(self, **kwargs) -> NoReturn:
    """Always raises: repeated scalar containers cannot be pickled."""
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")
212
213
# TODO: Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """List-like container for repeated composite (message) fields.

  Elements are always instances created by this container from the
  descriptor's concrete class; messages handed in by the caller are copied
  into such new instances rather than stored directly.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """Creates an empty container.

    A descriptor is passed in instead of the generated class directly, since
    at the time a RepeatedCompositeFieldContainer is constructed the type
    that will be contained in it has not necessarily been initialized yet.

    Args:
      message_listener: A MessageListener implementation. Its Modified()
        method is called when this container is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container. Its _concrete_class
        attribute is used to create the elements.
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Creates a new element at the end of the list and returns it.

    Keyword arguments are passed to the element's constructor.
    """
    element = self._message_descriptor._concrete_class(**kwargs)
    listener = self._message_listener
    element._SetListener(listener)
    self._values.append(element)
    # Skip the notification when the listener is already marked dirty.
    if not listener.dirty:
      listener.Modified()
    return element

  def append(self, value: _T) -> None:
    """Adds a copy of `value` at the end of the list."""
    element = self._message_descriptor._concrete_class()
    listener = self._message_listener
    element._SetListener(listener)
    element.CopyFrom(value)
    self._values.append(element)
    if not listener.dirty:
      listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Places a copy of `value` at index `key`."""
    element = self._message_descriptor._concrete_class()
    listener = self._message_listener
    element._SetListener(listener)
    element.CopyFrom(value)
    self._values.insert(key, element)
    if not listener.dirty:
      listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Appends a copy of every message in `elem_seq`.

    The messages are expected to be of the same type as this container's
    elements. The listener is notified once, after all copies were appended.
    """
    element_class = self._message_descriptor._concrete_class
    notifier = self._message_listener
    store = self._values
    for source in elem_seq:
      copied = element_class()
      copied._SetListener(notifier)
      copied.MergeFrom(source)
      store.append(copied)
    notifier.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends copies of the messages of another repeated field.

    This is equivalent to extend().
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Deletes the first occurrence of `elem`, like list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes the item at index `key` (default: last) and returns it."""
    popped = self._values[key]
    # Deleting through __delitem__ also takes care of notifying the listener.
    del self[key]
    return popped

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    """Always raises TypeError; item assignment is not supported.

    The method only exists so that the class is structurally compatible with
    typing.MutableSequence.
    """
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Removes the element(s) selected by an index or slice."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the elements with those of another container of this type.

    Raises:
      TypeError: If `other` is not an instance of this container's class.
    """
    if other is self:
      return True
    if isinstance(other, self.__class__):
      return self._values == other._values
    raise TypeError('Can only compare repeated composite fields against '
                    'other repeated composite fields.')
331
332
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for maps with scalar values.

  Reading a missing key through map[key] inserts the value checker's default
  value for that key, similar to collections.defaultdict.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the value stored for `key`, inserting the default if missing.

    A missing key is type-checked and then stored with the value checker's
    DefaultValue(). The message listener is not notified on that path.
    """
    try:
      return self._values[key]
    except KeyError:
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    """Returns whether `item` is a key of the map, after type-checking it."""
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    """Returns the value for `key`, or `default` without inserting the key."""
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    """Type-checks `key` and `value`, stores them and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    """Removes `key` (KeyError if absent) and notifies the listener."""
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    """Returns the number of entries in the map."""
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    """Iterates over the keys of the map."""
    return iter(self._values)

  def __repr__(self) -> str:
    """Returns the repr() of the underlying dict."""
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    """Copies all entries of `other` into this map, overwriting equal keys.

    Keys and values are taken from `other` without being type-checked again.
    """
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    """Invalidates iterators that were created over the current entries."""
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes. The map continues with
    # a copy of the entries, while the dict that existing iterators refer to
    # gets an extra placeholder entry.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    """Removes all entries and notifies the listener."""
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    """Returns the _concrete_class of the map entry descriptor."""
    return self._entry_descriptor._concrete_class
433
434
class MessageMap(MutableMapping[_K, _V]):
  """Type-checked, dict-like container for maps with submessage values.

  Values cannot be assigned directly. Reading map[key] returns the submessage
  stored for the key and creates an empty one first when the key is missing.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """Creates an empty map.

    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor whose _concrete_class is used to
        create the submessage values of this map.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._values = {}
    self._entry_descriptor = entry_descriptor
    self._key_checker = key_checker
    self._message_descriptor = message_descriptor
    self._message_listener = message_listener

  def __getitem__(self, key: _K) -> _V:
    """Returns the submessage for `key`, creating it when it is missing.

    A newly created submessage is wired to this map's listener, and the
    listener is notified about the insertion.
    """
    checked_key = self._key_checker.CheckValue(key)
    try:
      return self._values[checked_key]
    except KeyError:
      listener = self._message_listener
      created = self._message_descriptor._concrete_class()
      created._SetListener(listener)
      self._values[checked_key] = created
      listener.Modified()
      return created

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map.  This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # The inherited get() would go through __getitem__, whose defaultdict-like
  # behavior always inserts the key, so it has to be overridden explicitly.
  def get(self, key, default=None):
    """Returns the submessage for `key`, or `default` without inserting."""
    return self[key] if key in self else default

  def __contains__(self, item: _K) -> bool:
    """Returns whether the type-checked `item` is a key of the map."""
    checked_item = self._key_checker.CheckValue(item)
    return checked_item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    """Always raises ValueError; submessage values cannot be assigned."""
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    """Removes the type-checked `key` and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    del self._values[checked_key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    """Returns the number of entries in the map."""
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    """Iterates over the keys of the map."""
    return iter(self._values)

  def __repr__(self) -> str:
    """Returns the repr() of the underlying dict."""
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    """Replaces or adds the entries of `other`, copying each submessage."""
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      # An existing entry is therefore dropped before the copy is made.
      if key in self:
        del self[key]
      replacement = self[key]
      replacement.CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    """Invalidates iterators that were created over the current entries."""
    # Changing the size of the dict appears to be the only reliable way to
    # invalidate its iterators: the map continues with a copy of the entries,
    # and the dict the old iterators refer to gets a placeholder entry.
    stale = self._values
    self._values = stale.copy()
    stale[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    """Removes all entries and notifies the listener."""
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    """Returns the _concrete_class of the map entry descriptor."""
    return self._entry_descriptor._concrete_class
554
555
class _UnknownField:
  """A parsed unknown field.

  Holds the field number, the wire type and the data of one unknown field.
  Ordering looks at the field number only; equality looks at all three
  attributes.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    """Stores the three components of the unknown field."""
    self._data = data
    self._wire_type = wire_type
    self._field_number = field_number

  def __lt__(self, other):
    """Orders unknown fields by field number."""
    # pylint: disable=protected-access
    mine, theirs = self._field_number, other._field_number
    return mine < theirs

  def __eq__(self, other):
    """Returns whether field number, wire type and data all match."""
    if other is self:
      return True
    # pylint: disable=protected-access
    return (
        self._field_number == other._field_number
        and self._wire_type == other._wire_type
        and self._data == other._data
    )
579
580
class UnknownFieldRef:
  """Reference to one entry of a parent container, identified by its index.

  The properties look the entry up in the parent on every access, after
  checking that the parent is non-empty and the index is still in range.
  """

  def __init__(self, parent, index):
    """Remembers the parent container and the position of the entry in it."""
    self._index = index
    self._parent = parent

  def _check_valid(self):
    """Raises ValueError when the referenced entry no longer exists."""
    container = self._parent
    # A falsy (e.g. empty) parent and an out-of-range index are reported the
    # same way.
    if not container or self._index >= len(container):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  @property
  def field_number(self):
    """The field number of the referenced unknown field."""
    self._check_valid()
    # pylint: disable=protected-access
    entry = self._parent._internal_get(self._index)
    return entry._field_number

  @property
  def wire_type(self):
    """The wire type of the referenced unknown field."""
    self._check_valid()
    # pylint: disable=protected-access
    entry = self._parent._internal_get(self._index)
    return entry._wire_type

  @property
  def data(self):
    """The data of the referenced unknown field."""
    self._check_valid()
    # pylint: disable=protected-access
    entry = self._parent._internal_get(self._index)
    return entry._data
612
613
class UnknownFieldSet:
  """UnknownField container.

  Keeps the parsed unknown fields in a list. After _clear() the list is
  replaced by None, and indexing or taking the length raises ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at `index`.

    Negative indices count from the end, as for lists.

    Raises:
      ValueError: If the set has been cleared.
      IndexError: If `index` is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    position = index + size if index < 0 else index
    if position < 0 or position >= size:
      # Report the index the caller passed, not the normalized position.
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, position)

  def _internal_get(self, index):
    """Returns the stored field object at `index` (no reference wrapper)."""
    return self._values[index]

  def __len__(self):
    """Returns the number of unknown fields.

    Raises:
      ValueError: If the set has been cleared.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    """Appends a new unknown field and returns the created field object."""
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    """Yields an UnknownFieldRef for every position in the set."""
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    """Appends the fields of `other`; a None `other` is ignored."""
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    """Compares the fields with those of `other`, ignoring their order.

    A None `other` is considered equal to an empty set.
    """
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    """Clears nested UnknownFieldSets recursively and drops the field list."""
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    # None (rather than an empty list) marks the set as no longer existing.
    self._values = None
678