# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Provides type checking routines.

This module defines type checking utilities, mostly in the form of
dictionaries keyed by field type:

GetTypeChecker: A function returning the value validation object for a field.
  Scalar fields are served from the private _VALUE_CHECKERS dictionary, which
  maps C++ types to value validation objects.
TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
  function.
TYPE_TO_ENCODER: A dictionary with field types and encoder constructors.
TYPE_TO_SIZER: A dictionary with field types and sizer constructors.
TYPE_TO_DECODER: A dictionary with field types and decoder constructors.
FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field types and their
  corresponding wire types.
"""

__author__ = 'robinson@google.com (Will Robinson)'

import math
import numbers
import struct

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import descriptor

_FieldDescriptor = descriptor.FieldDescriptor


def TruncateToFourByteFloat(original):
  """Returns `original` reduced to 4-byte (single precision) float precision.

  The value is packed with struct's little-endian 'f' format and unpacked
  again, so the result is what a float field actually stores.

  Args:
    original: The value to convert.

  Returns:
    A Python float holding the 4-byte float representation of `original`.

  Raises:
    OverflowError: If `original` is finite but too large to be packed as a
      4-byte float (raised by struct.pack).
  """
  return struct.unpack('<f', struct.pack('<f', original))[0]


def ToShortestFloat(original):
  """Returns the shortest float that has same value in wire.

  The wire value is the 4-byte float that `original` is stored as. The result
  is the decimal with the fewest significant digits (at least 6) that
  converts back to that same 4-byte float.

  Args:
    original: A float, normally one that already holds a 4-byte float value
      (e.g. a value read back from a float field).

  Returns:
    The shortest float that round-trips to the same 4-byte value. NaN and
    infinities are returned unchanged.

  Raises:
    OverflowError: If `original` is finite but too large for a 4-byte float.
  """
  # NaN never compares equal to anything, so the search below would never
  # terminate for it.
  if math.isnan(original):
    return original
  # Search against the 4-byte value rather than `original` itself: a double
  # that is not exactly representable as a 4-byte float would otherwise never
  # match and the loop would spin forever. For inputs that already hold a
  # 4-byte value, `target == original`, so the result is unchanged.
  target = TruncateToFourByteFloat(original)
  # All 4 byte floats have between 6 and 9 significant digits, so we
  # start with 6 as the lower bound; 9 digits always round-trip, which
  # bounds the loop.
  # It has to be iterative because use '.9g' directly can not get rid
  # of the noises for most values. For example if set a float_field=0.9
  # use '.9g' will print 0.899999976.
  precision = 6
  rounded = float('{0:.{1}g}'.format(target, precision))
  while TruncateToFourByteFloat(rounded) != target:
    precision += 1
    rounded = float('{0:.{1}g}'.format(target, precision))
  return rounded


def GetTypeChecker(field):
  """Returns a type checker for a message field of the specified types.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A checker object exposing CheckValue() and DefaultValue() which can be
    used to verify the types of values assigned to a field of the specified
    type. UTF-8 string fields get a UnicodeValueChecker, closed enums get an
    EnumValueChecker, and open enums accept any int32.

  Raises:
    KeyError: If field.cpp_type has no entry in _VALUE_CHECKERS (for example
      CPPTYPE_MESSAGE).
  """
  if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and
      field.type == _FieldDescriptor.TYPE_STRING):
    return UnicodeValueChecker()
  if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
    if field.enum_type.is_closed:
      return EnumValueChecker(field.enum_type)
    else:
      # When open enums are supported, any int32 can be assigned.
      return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
  return _VALUE_CHECKERS[field.cpp_type]


# None of the typecheckers below make any attempt to guard against people
# subclassing builtin types and doing weird things.  We're not trying to
# protect against malicious clients here, just people accidentally shooting
# themselves in the foot in obvious ways.
class TypeChecker(object):

  """Type checker used to catch type errors as early as possible
  when the client is setting scalar fields in protocol messages.
  """

  def __init__(self, *acceptable_types):
    # Tuple of types passed straight to isinstance() in CheckValue.
    self._acceptable_types = acceptable_types

  def CheckValue(self, proposed_value):
    """Type check the provided value and return it.

    The returned value might have been normalized to another type.

    Raises:
      TypeError: If proposed_value is not an instance of one of the
        acceptable types.
    """
    if not isinstance(proposed_value, self._acceptable_types):
      message = ('%.1024r has type %s, but expected one of: %s' %
                 (proposed_value, type(proposed_value), self._acceptable_types))
      raise TypeError(message)
    return proposed_value


class TypeCheckerWithDefault(TypeChecker):

  """TypeChecker that also reports a fixed default value."""

  def __init__(self, default_value, *acceptable_types):
    TypeChecker.__init__(self, *acceptable_types)
    self._default_value = default_value

  def DefaultValue(self):
    return self._default_value


class BoolValueChecker(object):
  """Type checker used for bool fields."""

  def CheckValue(self, proposed_value):
    """Accepts anything with __index__ (except numpy arrays); returns a bool.

    Raises:
      TypeError: If proposed_value has no __index__ or is a numpy ndarray.
    """
    # numpy arrays are rejected explicitly: they define __index__ but are not
    # scalar values.
    if not hasattr(proposed_value, '__index__') or (
        type(proposed_value).__module__ == 'numpy' and
        type(proposed_value).__name__ == 'ndarray'):
      message = ('%.1024r has type %s, but expected one of: %s' %
                 (proposed_value, type(proposed_value), (bool, int)))
      raise TypeError(message)
    return bool(proposed_value)

  def DefaultValue(self):
    return False


# IntValueChecker and its subclasses perform integer type-checks
# and bounds-checks.
class IntValueChecker(object):

  """Checker used for integer fields.  Performs type-check and range check.

  Subclasses define the inclusive bounds _MIN and _MAX.
  """

  def CheckValue(self, proposed_value):
    """Type- and range-checks proposed_value and returns it as an int.

    Raises:
      TypeError: If proposed_value has no __index__ or is a numpy ndarray.
      ValueError: If the value lies outside [_MIN, _MAX].
    """
    if not hasattr(proposed_value, '__index__') or (
        type(proposed_value).__module__ == 'numpy' and
        type(proposed_value).__name__ == 'ndarray'):
      message = ('%.1024r has type %s, but expected one of: %s' %
                 (proposed_value, type(proposed_value), (int,)))
      raise TypeError(message)

    # We force all values to int to make alternate implementations where the
    # distinction is more significant (e.g. the C++ implementation) simpler.
    int_value = int(proposed_value)
    if not self._MIN <= int_value <= self._MAX:
      raise ValueError('Value out of range: %d' % proposed_value)
    return int_value

  def DefaultValue(self):
    return 0


class EnumValueChecker(object):

  """Checker used for enum fields.  Performs type-check and range check."""

  def __init__(self, enum_type):
    self._enum_type = enum_type

  def CheckValue(self, proposed_value):
    """Checks that proposed_value is an integral, known enum number.

    Raises:
      TypeError: If proposed_value is not a numbers.Integral.
      ValueError: If the number is not in enum_type.values_by_number.
    """
    if not isinstance(proposed_value, numbers.Integral):
      message = ('%.1024r has type %s, but expected one of: %s' %
                 (proposed_value, type(proposed_value), (int,)))
      raise TypeError(message)
    if int(proposed_value) not in self._enum_type.values_by_number:
      raise ValueError('Unknown enum value: %d' % proposed_value)
    return proposed_value

  def DefaultValue(self):
    # Number of the first entry in the enum's values list.
    return self._enum_type.values[0].number


class UnicodeValueChecker(object):

  """Checker used for string fields.

  Always returns a unicode value, even if the input is of type bytes.
  """

  def CheckValue(self, proposed_value):
    """Validates proposed_value as UTF-8 text and returns it as str.

    Raises:
      TypeError: If proposed_value is neither bytes nor str.
      ValueError: If bytes are not valid UTF-8, or a str cannot be encoded
        as UTF-8 (e.g. it contains lone surrogates).
    """
    if not isinstance(proposed_value, (bytes, str)):
      message = ('%.1024r has type %s, but expected one of: %s' %
                 (proposed_value, type(proposed_value), (bytes, str)))
      raise TypeError(message)

    # If the value is of type 'bytes' make sure that it is valid UTF-8 data.
    if isinstance(proposed_value, bytes):
      try:
        proposed_value = proposed_value.decode('utf-8')
      except UnicodeDecodeError:
        raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 '
                         'encoding. Non-UTF-8 strings must be converted to '
                         'unicode objects before being added.' %
                         (proposed_value))
    else:
      try:
        proposed_value.encode('utf8')
      except UnicodeEncodeError:
        raise ValueError('%.1024r isn\'t a valid unicode string and '
                         'can\'t be encoded in UTF-8.'%
                         (proposed_value))

    return proposed_value

  def DefaultValue(self):
    return u""


class Int32ValueChecker(IntValueChecker):
  # We're sure to use ints instead of longs here since comparison may be more
  # efficient.
  _MIN = -2147483648
  _MAX = 2147483647


class Uint32ValueChecker(IntValueChecker):
  _MIN = 0
  _MAX = (1 << 32) - 1


class Int64ValueChecker(IntValueChecker):
  _MIN = -(1 << 63)
  _MAX = (1 << 63) - 1


class Uint64ValueChecker(IntValueChecker):
  _MIN = 0
  _MAX = (1 << 64) - 1


# The max 4 bytes float is about 3.4028234663852886e+38
_FLOAT_MAX = float.fromhex('0x1.fffffep+127')
_FLOAT_MIN = -_FLOAT_MAX
_INF = float('inf')
_NEG_INF = float('-inf')


class DoubleValueChecker(object):
  """Checker used for double fields.

  Performs type-check and range check.
  """

  def CheckValue(self, proposed_value):
    """Check and convert proposed_value to float.

    Raises:
      TypeError: If proposed_value has neither __float__ nor __index__, or is
        a numpy ndarray.
    """
    if (not hasattr(proposed_value, '__float__') and
        not hasattr(proposed_value, '__index__')) or (
            type(proposed_value).__module__ == 'numpy' and
            type(proposed_value).__name__ == 'ndarray'):
      message = ('%.1024r has type %s, but expected one of: int, float' %
                 (proposed_value, type(proposed_value)))
      raise TypeError(message)
    return float(proposed_value)

  def DefaultValue(self):
    return 0.0


class FloatValueChecker(DoubleValueChecker):
  """Checker used for float fields.

  Performs type-check and range check.

  Values exceeding a 32-bit float will be converted to inf/-inf.
  """

  def CheckValue(self, proposed_value):
    """Check and convert proposed_value to float."""
    converted_value = super().CheckValue(proposed_value)
    # This inf rounding matches the C++ proto SafeDoubleToFloat logic.
    if converted_value > _FLOAT_MAX:
      return _INF
    if converted_value < _FLOAT_MIN:
      return _NEG_INF

    return TruncateToFourByteFloat(converted_value)

# Type-checkers for all scalar CPPTYPEs.
_VALUE_CHECKERS = {
    _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
    _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
    _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
    _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
    _FieldDescriptor.CPPTYPE_DOUBLE: DoubleValueChecker(),
    _FieldDescriptor.CPPTYPE_FLOAT: FloatValueChecker(),
    _FieldDescriptor.CPPTYPE_BOOL: BoolValueChecker(),
    _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}


# Map from field type to a function F, such that F(field_num, value)
# gives the total byte size for a value of the given type.  This
# byte size includes tag information and any other additional space
# associated with serializing "value".
TYPE_TO_BYTE_SIZE_FN = dict((
    # (field type, size function) pairs.
    (_FieldDescriptor.TYPE_DOUBLE, wire_format.DoubleByteSize),
    (_FieldDescriptor.TYPE_FLOAT, wire_format.FloatByteSize),
    (_FieldDescriptor.TYPE_INT64, wire_format.Int64ByteSize),
    (_FieldDescriptor.TYPE_UINT64, wire_format.UInt64ByteSize),
    (_FieldDescriptor.TYPE_INT32, wire_format.Int32ByteSize),
    (_FieldDescriptor.TYPE_FIXED64, wire_format.Fixed64ByteSize),
    (_FieldDescriptor.TYPE_FIXED32, wire_format.Fixed32ByteSize),
    (_FieldDescriptor.TYPE_BOOL, wire_format.BoolByteSize),
    (_FieldDescriptor.TYPE_STRING, wire_format.StringByteSize),
    (_FieldDescriptor.TYPE_GROUP, wire_format.GroupByteSize),
    (_FieldDescriptor.TYPE_MESSAGE, wire_format.MessageByteSize),
    (_FieldDescriptor.TYPE_BYTES, wire_format.BytesByteSize),
    (_FieldDescriptor.TYPE_UINT32, wire_format.UInt32ByteSize),
    (_FieldDescriptor.TYPE_ENUM, wire_format.EnumByteSize),
    (_FieldDescriptor.TYPE_SFIXED32, wire_format.SFixed32ByteSize),
    (_FieldDescriptor.TYPE_SFIXED64, wire_format.SFixed64ByteSize),
    (_FieldDescriptor.TYPE_SINT32, wire_format.SInt32ByteSize),
    (_FieldDescriptor.TYPE_SINT64, wire_format.SInt64ByteSize),
))


# Maps from field types to encoder constructors.
TYPE_TO_ENCODER = dict((
    # (field type, encoder constructor) pairs.
    (_FieldDescriptor.TYPE_DOUBLE, encoder.DoubleEncoder),
    (_FieldDescriptor.TYPE_FLOAT, encoder.FloatEncoder),
    (_FieldDescriptor.TYPE_INT64, encoder.Int64Encoder),
    (_FieldDescriptor.TYPE_UINT64, encoder.UInt64Encoder),
    (_FieldDescriptor.TYPE_INT32, encoder.Int32Encoder),
    (_FieldDescriptor.TYPE_FIXED64, encoder.Fixed64Encoder),
    (_FieldDescriptor.TYPE_FIXED32, encoder.Fixed32Encoder),
    (_FieldDescriptor.TYPE_BOOL, encoder.BoolEncoder),
    (_FieldDescriptor.TYPE_STRING, encoder.StringEncoder),
    (_FieldDescriptor.TYPE_GROUP, encoder.GroupEncoder),
    (_FieldDescriptor.TYPE_MESSAGE, encoder.MessageEncoder),
    (_FieldDescriptor.TYPE_BYTES, encoder.BytesEncoder),
    (_FieldDescriptor.TYPE_UINT32, encoder.UInt32Encoder),
    (_FieldDescriptor.TYPE_ENUM, encoder.EnumEncoder),
    (_FieldDescriptor.TYPE_SFIXED32, encoder.SFixed32Encoder),
    (_FieldDescriptor.TYPE_SFIXED64, encoder.SFixed64Encoder),
    (_FieldDescriptor.TYPE_SINT32, encoder.SInt32Encoder),
    (_FieldDescriptor.TYPE_SINT64, encoder.SInt64Encoder),
))


# Maps from field types to sizer constructors.
TYPE_TO_SIZER = dict((
    # (field type, sizer constructor) pairs.
    (_FieldDescriptor.TYPE_DOUBLE, encoder.DoubleSizer),
    (_FieldDescriptor.TYPE_FLOAT, encoder.FloatSizer),
    (_FieldDescriptor.TYPE_INT64, encoder.Int64Sizer),
    (_FieldDescriptor.TYPE_UINT64, encoder.UInt64Sizer),
    (_FieldDescriptor.TYPE_INT32, encoder.Int32Sizer),
    (_FieldDescriptor.TYPE_FIXED64, encoder.Fixed64Sizer),
    (_FieldDescriptor.TYPE_FIXED32, encoder.Fixed32Sizer),
    (_FieldDescriptor.TYPE_BOOL, encoder.BoolSizer),
    (_FieldDescriptor.TYPE_STRING, encoder.StringSizer),
    (_FieldDescriptor.TYPE_GROUP, encoder.GroupSizer),
    (_FieldDescriptor.TYPE_MESSAGE, encoder.MessageSizer),
    (_FieldDescriptor.TYPE_BYTES, encoder.BytesSizer),
    (_FieldDescriptor.TYPE_UINT32, encoder.UInt32Sizer),
    (_FieldDescriptor.TYPE_ENUM, encoder.EnumSizer),
    (_FieldDescriptor.TYPE_SFIXED32, encoder.SFixed32Sizer),
    (_FieldDescriptor.TYPE_SFIXED64, encoder.SFixed64Sizer),
    (_FieldDescriptor.TYPE_SINT32, encoder.SInt32Sizer),
    (_FieldDescriptor.TYPE_SINT64, encoder.SInt64Sizer),
))


# Maps from field type to a decoder constructor.
TYPE_TO_DECODER = dict((
    # (field type, decoder constructor) pairs.
    (_FieldDescriptor.TYPE_DOUBLE, decoder.DoubleDecoder),
    (_FieldDescriptor.TYPE_FLOAT, decoder.FloatDecoder),
    (_FieldDescriptor.TYPE_INT64, decoder.Int64Decoder),
    (_FieldDescriptor.TYPE_UINT64, decoder.UInt64Decoder),
    (_FieldDescriptor.TYPE_INT32, decoder.Int32Decoder),
    (_FieldDescriptor.TYPE_FIXED64, decoder.Fixed64Decoder),
    (_FieldDescriptor.TYPE_FIXED32, decoder.Fixed32Decoder),
    (_FieldDescriptor.TYPE_BOOL, decoder.BoolDecoder),
    (_FieldDescriptor.TYPE_STRING, decoder.StringDecoder),
    (_FieldDescriptor.TYPE_GROUP, decoder.GroupDecoder),
    (_FieldDescriptor.TYPE_MESSAGE, decoder.MessageDecoder),
    (_FieldDescriptor.TYPE_BYTES, decoder.BytesDecoder),
    (_FieldDescriptor.TYPE_UINT32, decoder.UInt32Decoder),
    (_FieldDescriptor.TYPE_ENUM, decoder.EnumDecoder),
    (_FieldDescriptor.TYPE_SFIXED32, decoder.SFixed32Decoder),
    (_FieldDescriptor.TYPE_SFIXED64, decoder.SFixed64Decoder),
    (_FieldDescriptor.TYPE_SINT32, decoder.SInt32Decoder),
    (_FieldDescriptor.TYPE_SINT64, decoder.SInt64Decoder),
))

# Maps from field type to expected wiretype.
FIELD_TYPE_TO_WIRE_TYPE = dict((
    # (field type, wire type) pairs.
    (_FieldDescriptor.TYPE_DOUBLE, wire_format.WIRETYPE_FIXED64),
    (_FieldDescriptor.TYPE_FLOAT, wire_format.WIRETYPE_FIXED32),
    (_FieldDescriptor.TYPE_INT64, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_UINT64, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_INT32, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_FIXED64, wire_format.WIRETYPE_FIXED64),
    (_FieldDescriptor.TYPE_FIXED32, wire_format.WIRETYPE_FIXED32),
    (_FieldDescriptor.TYPE_BOOL, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_STRING, wire_format.WIRETYPE_LENGTH_DELIMITED),
    (_FieldDescriptor.TYPE_GROUP, wire_format.WIRETYPE_START_GROUP),
    (_FieldDescriptor.TYPE_MESSAGE, wire_format.WIRETYPE_LENGTH_DELIMITED),
    (_FieldDescriptor.TYPE_BYTES, wire_format.WIRETYPE_LENGTH_DELIMITED),
    (_FieldDescriptor.TYPE_UINT32, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_ENUM, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_SFIXED32, wire_format.WIRETYPE_FIXED32),
    (_FieldDescriptor.TYPE_SFIXED64, wire_format.WIRETYPE_FIXED64),
    (_FieldDescriptor.TYPE_SINT32, wire_format.WIRETYPE_VARINT),
    (_FieldDescriptor.TYPE_SINT64, wire_format.WIRETYPE_VARINT),
))