1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library of dtypes (Tensor element types).""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21from six.moves import builtins 22 23from tensorflow.core.framework import types_pb2 24from tensorflow.python import pywrap_tensorflow 25from tensorflow.python.util.tf_export import tf_export 26 27_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() 28 29 30@tf_export("dtypes.DType", "DType") 31class DType(object): 32 """Represents the type of the elements in a `Tensor`. 33 34 The following `DType` objects are defined: 35 36 * `tf.float16`: 16-bit half-precision floating-point. 37 * `tf.float32`: 32-bit single-precision floating-point. 38 * `tf.float64`: 64-bit double-precision floating-point. 39 * `tf.bfloat16`: 16-bit truncated floating-point. 40 * `tf.complex64`: 64-bit single-precision complex. 41 * `tf.complex128`: 128-bit double-precision complex. 42 * `tf.int8`: 8-bit signed integer. 43 * `tf.uint8`: 8-bit unsigned integer. 44 * `tf.uint16`: 16-bit unsigned integer. 45 * `tf.uint32`: 32-bit unsigned integer. 46 * `tf.uint64`: 64-bit unsigned integer. 47 * `tf.int16`: 16-bit signed integer. 48 * `tf.int32`: 32-bit signed integer. 49 * `tf.int64`: 64-bit signed integer. 
50 * `tf.bool`: Boolean. 51 * `tf.string`: String. 52 * `tf.qint8`: Quantized 8-bit signed integer. 53 * `tf.quint8`: Quantized 8-bit unsigned integer. 54 * `tf.qint16`: Quantized 16-bit signed integer. 55 * `tf.quint16`: Quantized 16-bit unsigned integer. 56 * `tf.qint32`: Quantized 32-bit signed integer. 57 * `tf.resource`: Handle to a mutable resource. 58 * `tf.variant`: Values of arbitrary types. 59 60 In addition, variants of these types with the `_ref` suffix are 61 defined for reference-typed tensors. 62 63 The `tf.as_dtype()` function converts numpy types and string type 64 names to a `DType` object. 65 """ 66 67 def __init__(self, type_enum): 68 """Creates a new `DataType`. 69 70 NOTE(mrry): In normal circumstances, you should not need to 71 construct a `DataType` object directly. Instead, use the 72 `tf.as_dtype()` function. 73 74 Args: 75 type_enum: A `types_pb2.DataType` enum value. 76 77 Raises: 78 TypeError: If `type_enum` is not a value `types_pb2.DataType`. 79 80 """ 81 # TODO(mrry): Make the necessary changes (using __new__) to ensure 82 # that calling this returns one of the interned values. 
83 type_enum = int(type_enum) 84 if (type_enum not in types_pb2.DataType.values() or 85 type_enum == types_pb2.DT_INVALID): 86 raise TypeError( 87 "type_enum is not a valid types_pb2.DataType: %s" % type_enum) 88 self._type_enum = type_enum 89 90 @property 91 def _is_ref_dtype(self): 92 """Returns `True` if this `DType` represents a reference type.""" 93 return self._type_enum > 100 94 95 @property 96 def _as_ref(self): 97 """Returns a reference `DType` based on this `DType`.""" 98 if self._is_ref_dtype: 99 return self 100 else: 101 return _INTERN_TABLE[self._type_enum + 100] 102 103 @property 104 def base_dtype(self): 105 """Returns a non-reference `DType` based on this `DType`.""" 106 if self._is_ref_dtype: 107 return _INTERN_TABLE[self._type_enum - 100] 108 else: 109 return self 110 111 @property 112 def real_dtype(self): 113 """Returns the dtype correspond to this dtype's real part.""" 114 base = self.base_dtype 115 if base == complex64: 116 return float32 117 elif base == complex128: 118 return float64 119 else: 120 return self 121 122 @property 123 def is_numpy_compatible(self): 124 return self._type_enum not in _NUMPY_INCOMPATIBLE 125 126 @property 127 def as_numpy_dtype(self): 128 """Returns a `numpy.dtype` based on this `DType`.""" 129 return _TF_TO_NP[self._type_enum] 130 131 @property 132 def as_datatype_enum(self): 133 """Returns a `types_pb2.DataType` enum value based on this `DType`.""" 134 return self._type_enum 135 136 @property 137 def is_bool(self): 138 """Returns whether this is a boolean data type""" 139 return self.base_dtype == bool 140 141 @property 142 def is_integer(self): 143 """Returns whether this is a (non-quantized) integer type.""" 144 return (self.is_numpy_compatible and not self.is_quantized and 145 np.issubdtype(self.as_numpy_dtype, np.integer)) 146 147 @property 148 def is_floating(self): 149 """Returns whether this is a (non-quantized, real) floating point type.""" 150 return ((self.is_numpy_compatible and 151 
np.issubdtype(self.as_numpy_dtype, np.floating)) or 152 self.base_dtype == bfloat16) 153 154 @property 155 def is_complex(self): 156 """Returns whether this is a complex floating point type.""" 157 return self.base_dtype in (complex64, complex128) 158 159 @property 160 def is_quantized(self): 161 """Returns whether this is a quantized data type.""" 162 return self.base_dtype in _QUANTIZED_DTYPES_NO_REF 163 164 @property 165 def is_unsigned(self): 166 """Returns whether this type is unsigned. 167 168 Non-numeric, unordered, and quantized types are not considered unsigned, and 169 this function returns `False`. 170 171 Returns: 172 Whether a `DType` is unsigned. 173 """ 174 try: 175 return self.min == 0 176 except TypeError: 177 return False 178 179 @property 180 def min(self): 181 """Returns the minimum representable value in this data type. 182 183 Raises: 184 TypeError: if this is a non-numeric, unordered, or quantized type. 185 186 """ 187 if (self.is_quantized or 188 self.base_dtype in (bool, string, complex64, complex128)): 189 raise TypeError("Cannot find minimum value of %s." % self) 190 191 # there is no simple way to get the min value of a dtype, we have to check 192 # float and int types separately 193 try: 194 return np.finfo(self.as_numpy_dtype()).min 195 except: # bare except as possible raises by finfo not documented 196 try: 197 return np.iinfo(self.as_numpy_dtype()).min 198 except: 199 if self.base_dtype == bfloat16: 200 return _np_bfloat16(float.fromhex("-0x1.FEp127")) 201 raise TypeError("Cannot find minimum value of %s." % self) 202 203 @property 204 def max(self): 205 """Returns the maximum representable value in this data type. 206 207 Raises: 208 TypeError: if this is a non-numeric, unordered, or quantized type. 209 210 """ 211 if (self.is_quantized or 212 self.base_dtype in (bool, string, complex64, complex128)): 213 raise TypeError("Cannot find maximum value of %s." 
% self) 214 215 # there is no simple way to get the max value of a dtype, we have to check 216 # float and int types separately 217 try: 218 return np.finfo(self.as_numpy_dtype()).max 219 except: # bare except as possible raises by finfo not documented 220 try: 221 return np.iinfo(self.as_numpy_dtype()).max 222 except: 223 if self.base_dtype == bfloat16: 224 return _np_bfloat16(float.fromhex("0x1.FEp127")) 225 raise TypeError("Cannot find maximum value of %s." % self) 226 227 @property 228 def limits(self, clip_negative=True): 229 """Return intensity limits, i.e. (min, max) tuple, of the dtype. 230 Args: 231 clip_negative : bool, optional 232 If True, clip the negative range (i.e. return 0 for min intensity) 233 even if the image dtype allows negative values. 234 Returns 235 min, max : tuple 236 Lower and upper intensity limits. 237 """ 238 min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin 239 if clip_negative: 240 min = 0 # pylint: disable=redefined-builtin 241 return min, max 242 243 def is_compatible_with(self, other): 244 """Returns True if the `other` DType will be converted to this DType. 245 246 The conversion rules are as follows: 247 248 ```python 249 DType(T) .is_compatible_with(DType(T)) == True 250 DType(T) .is_compatible_with(DType(T).as_ref) == True 251 DType(T).as_ref.is_compatible_with(DType(T)) == False 252 DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True 253 ``` 254 255 Args: 256 other: A `DType` (or object that may be converted to a `DType`). 257 258 Returns: 259 True if a Tensor of the `other` `DType` will be implicitly converted to 260 this `DType`. 
261 """ 262 other = as_dtype(other) 263 return self._type_enum in (other.as_datatype_enum, 264 other.base_dtype.as_datatype_enum) 265 266 def __eq__(self, other): 267 """Returns True iff this DType refers to the same type as `other`.""" 268 if other is None: 269 return False 270 try: 271 dtype = as_dtype(other).as_datatype_enum 272 return self._type_enum == dtype # pylint: disable=protected-access 273 except TypeError: 274 return False 275 276 def __ne__(self, other): 277 """Returns True iff self != other.""" 278 return not self.__eq__(other) 279 280 @property 281 def name(self): 282 """Returns the string name for this `DType`.""" 283 return _TYPE_TO_STRING[self._type_enum] 284 285 def __str__(self): 286 return "<dtype: %r>" % self.name 287 288 def __repr__(self): 289 return "tf." + self.name 290 291 def __hash__(self): 292 return self._type_enum 293 294 def __reduce__(self): 295 return as_dtype, (self.name,) 296 297 @property 298 def size(self): 299 if (self._type_enum == types_pb2.DT_VARIANT or 300 self._type_enum == types_pb2.DT_RESOURCE): 301 return 1 302 return np.dtype(self.as_numpy_dtype).itemsize 303 304 305# Define data type range of numpy dtype 306dtype_range = { 307 np.bool_: (False, True), 308 np.bool8: (False, True), 309 np.uint8: (0, 255), 310 np.uint16: (0, 65535), 311 np.int8: (-128, 127), 312 np.int16: (-32768, 32767), 313 np.int64: (-2**63, 2**63 - 1), 314 np.uint64: (0, 2**64 - 1), 315 np.int32: (-2**31, 2**31 - 1), 316 np.uint32: (0, 2**32 - 1), 317 np.float32: (-1, 1), 318 np.float64: (-1, 1) 319} 320 321# Define standard wrappers for the types_pb2.DataType enum. 
resource = DType(types_pb2.DT_RESOURCE)
tf_export("dtypes.resource", "resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
tf_export("dtypes.variant", "variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
tf_export("dtypes.float16", "float16").export_constant(__name__, "float16")
half = float16
tf_export("dtypes.half", "half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
tf_export("dtypes.float32", "float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
tf_export("dtypes.float64", "float64").export_constant(__name__, "float64")
double = float64
tf_export("dtypes.double", "double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
tf_export("dtypes.int32", "int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64")
int16 = DType(types_pb2.DT_INT16)
tf_export("dtypes.int16", "int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
tf_export("dtypes.int8", "int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
tf_export("dtypes.string", "string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
tf_export("dtypes.complex64", "complex64").export_constant(
    __name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
tf_export("dtypes.complex128", "complex128").export_constant(
    __name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
tf_export("dtypes.int64", "int64").export_constant(__name__, "int64")
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
tf_export("dtypes.bool", "bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32")
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16")
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)

# Enum values that have no numpy equivalent (see DType.is_numpy_compatible).
_NUMPY_INCOMPATIBLE = frozenset([
    types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
    types_pb2.DT_RESOURCE_REF
])

# Maintain an intern table so that we don't have to create a large
# number of small objects.
_INTERN_TABLE = {
    types_pb2.DT_HALF: float16,
    types_pb2.DT_FLOAT: float32,
    types_pb2.DT_DOUBLE: float64,
    types_pb2.DT_INT32: int32,
    types_pb2.DT_UINT8: uint8,
    types_pb2.DT_UINT16: uint16,
    types_pb2.DT_UINT32: uint32,
    types_pb2.DT_UINT64: uint64,
    types_pb2.DT_INT16: int16,
    types_pb2.DT_INT8: int8,
    types_pb2.DT_STRING: string,
    types_pb2.DT_COMPLEX64: complex64,
    types_pb2.DT_COMPLEX128: complex128,
    types_pb2.DT_INT64: int64,
    types_pb2.DT_BOOL: bool,
    types_pb2.DT_QINT8: qint8,
    types_pb2.DT_QUINT8: quint8,
    types_pb2.DT_QINT16: qint16,
    types_pb2.DT_QUINT16: quint16,
    types_pb2.DT_QINT32: qint32,
    types_pb2.DT_BFLOAT16: bfloat16,
    types_pb2.DT_RESOURCE: resource,
    types_pb2.DT_VARIANT: variant,
    types_pb2.DT_HALF_REF: float16_ref,
    types_pb2.DT_FLOAT_REF: float32_ref,
    types_pb2.DT_DOUBLE_REF: float64_ref,
    types_pb2.DT_INT32_REF: int32_ref,
    types_pb2.DT_UINT32_REF: uint32_ref,
    types_pb2.DT_UINT8_REF: uint8_ref,
    types_pb2.DT_UINT16_REF: uint16_ref,
    types_pb2.DT_INT16_REF: int16_ref,
    types_pb2.DT_INT8_REF: int8_ref,
    types_pb2.DT_STRING_REF: string_ref,
    types_pb2.DT_COMPLEX64_REF: complex64_ref,
    types_pb2.DT_COMPLEX128_REF: complex128_ref,
    types_pb2.DT_INT64_REF: int64_ref,
    types_pb2.DT_UINT64_REF: uint64_ref,
    types_pb2.DT_BOOL_REF: bool_ref,
    types_pb2.DT_QINT8_REF: qint8_ref,
    types_pb2.DT_QUINT8_REF: quint8_ref,
    types_pb2.DT_QINT16_REF: qint16_ref,
    types_pb2.DT_QUINT16_REF: quint16_ref,
    types_pb2.DT_QINT32_REF: qint32_ref,
    types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
    types_pb2.DT_RESOURCE_REF: resource_ref,
    types_pb2.DT_VARIANT_REF: variant_ref,
}

# Standard mappings between types_pb2.DataType values and string names.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Inverse of _TYPE_TO_STRING: canonical name -> interned DType.
_STRING_TO_TF = {
    value: _INTERN_TABLE[key]
    for key, value in _TYPE_TO_STRING.items()
}
# Add non-canonical aliases.
_STRING_TO_TF["half"] = float16
_STRING_TO_TF["half_ref"] = float16_ref
_STRING_TO_TF["float"] = float32
_STRING_TO_TF["float_ref"] = float32_ref
_STRING_TO_TF["double"] = float64
_STRING_TO_TF["double_ref"] = float64_ref

# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
#
# NOTE: each field is declared as `(name, type)` rather than the legacy
# `(name, type, 1)`. NumPy has deprecated the latter since 1.17 (FutureWarning:
# it will be reinterpreted as a `(1,)`-shaped subarray). Under the current
# semantics, both spellings produce the same dtype.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is defined by a module import.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])

# Standard mappings between types_pb2.DataType values and numpy.dtypes.
# `np.bytes_` / `np.str_` are the same objects as the legacy `np.string_` /
# `np.unicode_` aliases, which were removed in NumPy 2.0.
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    np.object_: string,
    np.bytes_: string,
    np.str_: string,
    np.bool_: bool,
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}

# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simple aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
for pdt in [
    np.intc,
    np.uintc,
    np.int_,
    np.uint,
    np.longlong,
    np.ulonglong,
]:
  if pdt not in _NP_TO_TF:
    # Reuse the TF dtype of whichever already-mapped numpy type has an equal
    # dtype (e.g. np.longlong -> the entry for np.int64).
    _NP_TO_TF[pdt] = next(
        _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)

# Standard mappings from types_pb2.DataType values to numpy types.
# NOTE: `np.object_` / `np.bool_` replace the `np.object` / `np.bool` aliases
# of the builtins. NumPy deprecated those aliases in 1.20 and removed them in
# 1.24, where they raise AttributeError at import time.
_TF_TO_NP = {
    types_pb2.DT_HALF:
        np.float16,
    types_pb2.DT_FLOAT:
        np.float32,
    types_pb2.DT_DOUBLE:
        np.float64,
    types_pb2.DT_INT32:
        np.int32,
    types_pb2.DT_UINT8:
        np.uint8,
    types_pb2.DT_UINT16:
        np.uint16,
    types_pb2.DT_UINT32:
        np.uint32,
    types_pb2.DT_UINT64:
        np.uint64,
    types_pb2.DT_INT16:
        np.int16,
    types_pb2.DT_INT8:
        np.int8,
    # NOTE(touts): For strings we use np.object_ as it supports variable length
    # strings.
    types_pb2.DT_STRING:
        np.object_,
    types_pb2.DT_COMPLEX64:
        np.complex64,
    types_pb2.DT_COMPLEX128:
        np.complex128,
    types_pb2.DT_INT64:
        np.int64,
    types_pb2.DT_BOOL:
        np.bool_,
    types_pb2.DT_QINT8:
        _np_qint8,
    types_pb2.DT_QUINT8:
        _np_quint8,
    types_pb2.DT_QINT16:
        _np_qint16,
    types_pb2.DT_QUINT16:
        _np_quint16,
    types_pb2.DT_QINT32:
        _np_qint32,
    types_pb2.DT_BFLOAT16:
        _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF:
        np.float16,
    types_pb2.DT_FLOAT_REF:
        np.float32,
    types_pb2.DT_DOUBLE_REF:
        np.float64,
    types_pb2.DT_INT32_REF:
        np.int32,
    types_pb2.DT_UINT32_REF:
        np.uint32,
    types_pb2.DT_UINT8_REF:
        np.uint8,
    types_pb2.DT_UINT16_REF:
        np.uint16,
    types_pb2.DT_INT16_REF:
        np.int16,
    types_pb2.DT_INT8_REF:
        np.int8,
    types_pb2.DT_STRING_REF:
        np.object_,
    types_pb2.DT_COMPLEX64_REF:
        np.complex64,
    types_pb2.DT_COMPLEX128_REF:
        np.complex128,
    types_pb2.DT_INT64_REF:
        np.int64,
    types_pb2.DT_UINT64_REF:
        np.uint64,
    types_pb2.DT_BOOL_REF:
        np.bool_,
    types_pb2.DT_QINT8_REF:
        _np_qint8,
    types_pb2.DT_QUINT8_REF:
        _np_quint8,
    types_pb2.DT_QINT16_REF:
        _np_qint16,
    types_pb2.DT_QUINT16_REF:
        _np_quint16,
    types_pb2.DT_QINT32_REF:
        _np_qint32,
    types_pb2.DT_BFLOAT16_REF:
        _np_bfloat16,
}

_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
_QUANTIZED_DTYPES_REF = frozenset(
    [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES"]).export_constant(
        __name__, "QUANTIZED_DTYPES")

# Python builtin types accepted by `as_dtype`.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string
}

# Single lookup table combining every accepted key kind: enum values, string
# names, Python types and numpy types/dtypes.
_ANY_TO_TF = {}
_ANY_TO_TF.update(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)

# Ensure no collisions.
assert len(_ANY_TO_TF) == sum(len(d) for d in [
    _INTERN_TABLE,
    _STRING_TO_TF,
    _PYTHON_TO_TF,
    _NP_TO_TF
])


@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, a `numpy.dtype`, a numpy scalar type, or one of the
      Python types `float`, `bool` or `object`.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  if isinstance(type_value, DType):
    return type_value

  if isinstance(type_value, np.dtype):
    # Look up by the dtype's scalar type first. Structured dtypes (the
    # quantized ones) have a generic `.type`, so they fall through to the
    # direct lookup below.
    try:
      return _NP_TO_TF[type_value.type]
    except KeyError:
      pass

  try:
    return _ANY_TO_TF[type_value]
  except KeyError:
    pass

  raise TypeError(
      "Cannot convert value %r to a TensorFlow DType." % (type_value,))