• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides type checking routines.
9
10This module defines type checking utilities in the forms of dictionaries:
11
12VALUE_CHECKERS: A dictionary of field types and a value validation object.
13TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
14  function.
15TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization
16  function.
FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field types and their
  corresponding wire types.
19TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
20  function.
21"""
22
23__author__ = 'robinson@google.com (Will Robinson)'
24
25import struct
26import numbers
27
28from google.protobuf.internal import decoder
29from google.protobuf.internal import encoder
30from google.protobuf.internal import wire_format
31from google.protobuf import descriptor
32
33_FieldDescriptor = descriptor.FieldDescriptor
34
35
def TruncateToFourByteFloat(original):
  """Returns `original` after a round trip through a 4 byte float.

  Args:
    original: A number that struct can pack with the 'f' format.

  Returns:
    The Python float obtained by packing `original` as a little-endian
    4 byte float and unpacking the result again.
  """
  packed = struct.pack('<f', original)
  (truncated,) = struct.unpack('<f', packed)
  return truncated
38
39
def ToShortestFloat(original):
  """Returns the shortest float that has same value in wire.

  Args:
    original: The value of a 4 byte float field, as a Python float.

  Returns:
    The float with the fewest significant digits (trying 6 digits first)
    that has the same 4 byte float representation as `original`.  NaN is
    returned unchanged.
  """
  # NaN compares unequal to everything, itself included, so the search
  # below could never succeed for it.
  if original != original:
    return original
  # All 4 byte floats have between 6 and 9 significant digits, so we
  # start with 6 as the lower bound.
  # It has to be iterative because use '.9g' directly can not get rid
  # of the noises for most values. For example if set a float_field=0.9
  # use '.9g' will print 0.899999976.
  precision = 6
  rounded = float('{0:.{1}g}'.format(original, precision))
  # Compare against the wire value of `original` rather than `original`
  # itself, so the search also terminates when the caller passes a float
  # that is not exactly representable in 4 bytes.  For values that are
  # already 4 byte floats the two are equal.
  wire_value = TruncateToFourByteFloat(original)
  while TruncateToFourByteFloat(rounded) != wire_value:
    precision += 1
    rounded = float('{0:.{1}g}'.format(original, precision))
  return rounded
53
54
def GetTypeChecker(field):
  """Picks the checker used to validate values assigned to `field`.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A checker object whose CheckValue() verifies the types of values
    assigned to a field of the specified type.
  """
  cpp_type = field.cpp_type
  if cpp_type == _FieldDescriptor.CPPTYPE_STRING:
    # Only TYPE_STRING fields get the unicode checker; any other field with
    # CPPTYPE_STRING falls through to the table lookup below.
    if field.type == _FieldDescriptor.TYPE_STRING:
      return UnicodeValueChecker()
  elif cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
    enum_type = field.enum_type
    if enum_type.is_closed:
      return EnumValueChecker(enum_type)
    # When open enums are supported, any int32 can be assigned.
    return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
  return _VALUE_CHECKERS[cpp_type]
75
76
77# None of the typecheckers below make any attempt to guard against people
78# subclassing builtin types and doing weird things.  We're not trying to
79# protect against malicious clients here, just people accidentally shooting
80# themselves in the foot in obvious ways.
class TypeChecker(object):
  """Validates that a value assigned to a scalar field has an allowed type.

  Used to catch type errors as early as possible, when the client is
  setting scalar fields in protocol messages.
  """

  def __init__(self, *acceptable_types):
    """Remembers the types accepted by CheckValue().

    Args:
      *acceptable_types: The types a proposed value must be an instance of.
    """
    self._acceptable_types = acceptable_types

  def CheckValue(self, proposed_value):
    """Type checks the provided value and returns it.

    The returned value might have been normalized to another type; this
    implementation returns it unchanged.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      proposed_value.

    Raises:
      TypeError: If proposed_value is not an instance of an acceptable type.
    """
    if isinstance(proposed_value, self._acceptable_types):
      return proposed_value
    raise TypeError(
        '%.1024r has type %s, but expected one of: %s' %
        (proposed_value, type(proposed_value), self._acceptable_types))
100
101
class TypeCheckerWithDefault(TypeChecker):
  """A TypeChecker that also carries the default value of its field type."""

  def __init__(self, default_value, *acceptable_types):
    """Stores the default value alongside the acceptable types.

    Args:
      default_value: The value DefaultValue() returns.
      *acceptable_types: The types a proposed value must be an instance of.
    """
    super().__init__(*acceptable_types)
    self._default_value = default_value

  def DefaultValue(self):
    """Returns the default value given at construction."""
    return self._default_value
110
111
class BoolValueChecker(object):
  """Type checker used for bool fields."""

  def CheckValue(self, proposed_value):
    """Checks the type of proposed_value and converts it to bool.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      bool(proposed_value).

    Raises:
      TypeError: If proposed_value has no __index__ attribute or is a numpy
        ndarray.
    """
    value_type = type(proposed_value)
    # numpy arrays are rejected by type name, whether or not they expose
    # __index__.
    acceptable = hasattr(proposed_value, '__index__') and not (
        value_type.__module__ == 'numpy' and value_type.__name__ == 'ndarray')
    if acceptable:
      return bool(proposed_value)
    raise TypeError('%.1024r has type %s, but expected one of: %s' %
                    (proposed_value, value_type, (bool, int)))

  def DefaultValue(self):
    """Returns the default value of a bool field."""
    return False
126
127
128# IntValueChecker and its subclasses perform integer type-checks
129# and bounds-checks.
class IntValueChecker(object):
  """Checker used for integer fields.  Performs type-check and range check.

  The inclusive bounds are read from the _MIN and _MAX attributes, which
  subclasses are expected to define.
  """

  def CheckValue(self, proposed_value):
    """Checks the type and range of proposed_value and converts it to int.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      int(proposed_value).

    Raises:
      TypeError: If proposed_value has no __index__ attribute or is a numpy
        ndarray.
      ValueError: If the value lies outside [_MIN, _MAX].
    """
    value_type = type(proposed_value)
    if not hasattr(proposed_value, '__index__') or (
        value_type.__module__ == 'numpy' and value_type.__name__ == 'ndarray'):
      raise TypeError('%.1024r has type %s, but expected one of: %s' %
                      (proposed_value, value_type, (int,)))

    # We force all values to int to make alternate implementations where the
    # distinction is more significant (e.g. the C++ implementation) simpler.
    as_int = int(proposed_value)
    if as_int < self._MIN or as_int > self._MAX:
      raise ValueError('Value out of range: %d' % proposed_value)
    return as_int

  def DefaultValue(self):
    """Returns the default value of an integer field."""
    return 0
151
152
class EnumValueChecker(object):
  """Checker used for enum fields.  Performs type-check and range check."""

  def __init__(self, enum_type):
    """Stores the enum descriptor that values are validated against.

    Args:
      enum_type: An object exposing `values_by_number` and `values`.
    """
    self._enum_type = enum_type

  def CheckValue(self, proposed_value):
    """Checks that proposed_value is an integer naming a known enum value.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      proposed_value, unchanged.

    Raises:
      TypeError: If proposed_value is not a numbers.Integral instance.
      ValueError: If the number is not in the enum's values_by_number.
    """
    if not isinstance(proposed_value, numbers.Integral):
      raise TypeError('%.1024r has type %s, but expected one of: %s' %
                      (proposed_value, type(proposed_value), (int,)))
    known_numbers = self._enum_type.values_by_number
    if int(proposed_value) in known_numbers:
      return proposed_value
    raise ValueError('Unknown enum value: %d' % proposed_value)

  def DefaultValue(self):
    """Returns the number of the first value listed in the enum."""
    first_value = self._enum_type.values[0]
    return first_value.number
171
172
class UnicodeValueChecker(object):
  """Checker used for string fields.

  Always returns a str value, even if the input is of type bytes.
  """

  def CheckValue(self, proposed_value):
    """Checks that proposed_value is text that can be encoded as UTF-8.

    Args:
      proposed_value: A bytes or str value about to be assigned.

    Returns:
      The value as str; bytes input is decoded as UTF-8.

    Raises:
      TypeError: If proposed_value is neither bytes nor str.
      ValueError: If bytes input is not valid UTF-8, or str input cannot be
        encoded as UTF-8.
    """
    if isinstance(proposed_value, bytes):
      # Bytes are only accepted when they hold valid UTF-8 data.
      try:
        return proposed_value.decode('utf-8')
      except UnicodeDecodeError:
        raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 '
                         'encoding. Non-UTF-8 strings must be converted to '
                         'unicode objects before being added.' %
                         (proposed_value))

    if isinstance(proposed_value, str):
      # The encoded result is discarded; encoding is only attempted to
      # detect strings that cannot be represented in UTF-8.
      try:
        proposed_value.encode('utf8')
      except UnicodeEncodeError:
        raise ValueError('%.1024r isn\'t a valid unicode string and '
                         'can\'t be encoded in UTF-8.'%
                         (proposed_value))
      return proposed_value

    raise TypeError('%.1024r has type %s, but expected one of: %s' %
                    (proposed_value, type(proposed_value), (bytes, str)))

  def DefaultValue(self):
    """Returns the default value of a string field."""
    return ''
207
208
class Int32ValueChecker(IntValueChecker):
  """Integer checker bounded to the signed 32-bit range."""

  _MIN = -(1 << 31)
  _MAX = (1 << 31) - 1
214
215
class Uint32ValueChecker(IntValueChecker):
  """Integer checker bounded to the unsigned 32-bit range."""

  _MIN = 0
  _MAX = 0xFFFFFFFF
219
220
class Int64ValueChecker(IntValueChecker):
  """Integer checker bounded to the signed 64-bit range."""

  _MIN = -(2 ** 63)
  _MAX = 2 ** 63 - 1
224
225
class Uint64ValueChecker(IntValueChecker):
  """Integer checker bounded to the unsigned 64-bit range."""

  _MIN = 0
  _MAX = 2 ** 64 - 1
229
230
# Positive and negative infinity.
_INF = float('inf')
_NEG_INF = -_INF
# The max 4 bytes float is about 3.4028234663852886e+38; _FLOAT_MIN is its
# negation.
_FLOAT_MAX = float.fromhex('0x1.fffffep+127')
_FLOAT_MIN = -_FLOAT_MAX
236
237
class DoubleValueChecker(object):
  """Checker used for double fields.

  Performs a type-check and converts accepted values to float.
  """

  def CheckValue(self, proposed_value):
    """Check and convert proposed_value to float.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      float(proposed_value).

    Raises:
      TypeError: If proposed_value has neither __float__ nor __index__, or
        is a numpy ndarray.
    """
    value_type = type(proposed_value)
    convertible = (hasattr(proposed_value, '__float__') or
                   hasattr(proposed_value, '__index__'))
    is_ndarray = (value_type.__module__ == 'numpy' and
                  value_type.__name__ == 'ndarray')
    if is_ndarray or not convertible:
      raise TypeError(
          '%.1024r has type %s, but expected one of: int, float' %
          (proposed_value, value_type))
    return float(proposed_value)

  def DefaultValue(self):
    """Returns the default value of a double field."""
    return 0.0
257
258
class FloatValueChecker(DoubleValueChecker):
  """Checker used for float fields.

  Performs the DoubleValueChecker type-check, then narrows the result to a
  4 byte float.

  Values exceeding a 32-bit float will be converted to inf/-inf.
  """

  def CheckValue(self, proposed_value):
    """Check and convert proposed_value to float.

    Args:
      proposed_value: The value about to be assigned.

    Returns:
      inf or -inf when the converted value lies beyond _FLOAT_MAX or
      _FLOAT_MIN; otherwise the value truncated to a 4 byte float.
    """
    as_double = super().CheckValue(proposed_value)
    # This inf rounding matches the C++ proto SafeDoubleToFloat logic.
    if as_double > _FLOAT_MAX:
      result = _INF
    elif as_double < _FLOAT_MIN:
      result = _NEG_INF
    else:
      result = TruncateToFourByteFloat(as_double)
    return result
277
# Type-checkers for all scalar CPPTYPEs, keyed by _FieldDescriptor.CPPTYPE_*
# constant.  Each value is a single checker instance created when this module
# is imported.  CPPTYPE_ENUM has no entry: GetTypeChecker() handles enum
# fields itself.  The CPPTYPE_STRING entry accepts bytes only;
# GetTypeChecker() returns a UnicodeValueChecker for TYPE_STRING fields
# instead of consulting this table.
_VALUE_CHECKERS = {
    _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
    _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
    _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
    _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
    _FieldDescriptor.CPPTYPE_DOUBLE: DoubleValueChecker(),
    _FieldDescriptor.CPPTYPE_FLOAT: FloatValueChecker(),
    _FieldDescriptor.CPPTYPE_BOOL: BoolValueChecker(),
    _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}
289
290
# Map from field type to a function F, such that F(field_num, value)
# gives the total byte size for a value of the given type.  This
# byte size includes tag information and any other additional space
# associated with serializing "value".
# Keys are _FieldDescriptor.TYPE_* constants; the functions all come from
# the wire_format module.
TYPE_TO_BYTE_SIZE_FN = {
    _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
    _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
    _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
    _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
    _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
    _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
    _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
    _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
    _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
    _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
    _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
    _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
    _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
    _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
    _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
    _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
    _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
    _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize
    }
315
316
# Maps from field types to encoder constructors.
# Keys are _FieldDescriptor.TYPE_* constants; the constructors all come from
# the encoder module.
TYPE_TO_ENCODER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
    _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
    }
338
339
# Maps from field types to sizer constructors.
# Keys are _FieldDescriptor.TYPE_* constants; the constructors all come from
# the encoder module.
TYPE_TO_SIZER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
    _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
    }
361
362
# Maps from field type to a decoder constructor.
# Keys are _FieldDescriptor.TYPE_* constants; the constructors all come from
# the decoder module.
TYPE_TO_DECODER = {
    _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
    _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
    _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
    _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
    _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
    _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
    _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
    _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
    _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
    _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
    _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
    _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
    _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
    _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
    _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
    _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
    _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
    _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
    }
384
# Maps from field type to expected wiretype.
# Keys are _FieldDescriptor.TYPE_* constants; values are the
# wire_format.WIRETYPE_* constants.  TYPE_GROUP maps to the start-group
# wire type.
FIELD_TYPE_TO_WIRE_TYPE = {
    _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
    _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
    _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
    _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
    _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_STRING:
      wire_format.WIRETYPE_LENGTH_DELIMITED,
    _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
    _FieldDescriptor.TYPE_MESSAGE:
      wire_format.WIRETYPE_LENGTH_DELIMITED,
    _FieldDescriptor.TYPE_BYTES:
      wire_format.WIRETYPE_LENGTH_DELIMITED,
    _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
    _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
    _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
    }
409