1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library of dtypes (Tensor element types).""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.core.framework import types_pb2 23from tensorflow.python import pywrap_tensorflow 24from tensorflow.python.util.tf_export import tf_export 25 26_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() 27 28 29@tf_export("DType") 30class DType(object): 31 """Represents the type of the elements in a `Tensor`. 32 33 The following `DType` objects are defined: 34 35 * `tf.float16`: 16-bit half-precision floating-point. 36 * `tf.float32`: 32-bit single-precision floating-point. 37 * `tf.float64`: 64-bit double-precision floating-point. 38 * `tf.bfloat16`: 16-bit truncated floating-point. 39 * `tf.complex64`: 64-bit single-precision complex. 40 * `tf.complex128`: 128-bit double-precision complex. 41 * `tf.int8`: 8-bit signed integer. 42 * `tf.uint8`: 8-bit unsigned integer. 43 * `tf.uint16`: 16-bit unsigned integer. 44 * `tf.uint32`: 32-bit unsigned integer. 45 * `tf.uint64`: 64-bit unsigned integer. 46 * `tf.int16`: 16-bit signed integer. 47 * `tf.int32`: 32-bit signed integer. 48 * `tf.int64`: 64-bit signed integer. 49 * `tf.bool`: Boolean. 50 * `tf.string`: String. 
51 * `tf.qint8`: Quantized 8-bit signed integer. 52 * `tf.quint8`: Quantized 8-bit unsigned integer. 53 * `tf.qint16`: Quantized 16-bit signed integer. 54 * `tf.quint16`: Quantized 16-bit unsigned integer. 55 * `tf.qint32`: Quantized 32-bit signed integer. 56 * `tf.resource`: Handle to a mutable resource. 57 * `tf.variant`: Values of arbitrary types. 58 59 In addition, variants of these types with the `_ref` suffix are 60 defined for reference-typed tensors. 61 62 The `tf.as_dtype()` function converts numpy types and string type 63 names to a `DType` object. 64 """ 65 66 def __init__(self, type_enum): 67 """Creates a new `DataType`. 68 69 NOTE(mrry): In normal circumstances, you should not need to 70 construct a `DataType` object directly. Instead, use the 71 `tf.as_dtype()` function. 72 73 Args: 74 type_enum: A `types_pb2.DataType` enum value. 75 76 Raises: 77 TypeError: If `type_enum` is not a value `types_pb2.DataType`. 78 79 """ 80 # TODO(mrry): Make the necessary changes (using __new__) to ensure 81 # that calling this returns one of the interned values. 
82 type_enum = int(type_enum) 83 if (type_enum not in types_pb2.DataType.values() or 84 type_enum == types_pb2.DT_INVALID): 85 raise TypeError( 86 "type_enum is not a valid types_pb2.DataType: %s" % type_enum) 87 self._type_enum = type_enum 88 89 @property 90 def _is_ref_dtype(self): 91 """Returns `True` if this `DType` represents a reference type.""" 92 return self._type_enum > 100 93 94 @property 95 def _as_ref(self): 96 """Returns a reference `DType` based on this `DType`.""" 97 if self._is_ref_dtype: 98 return self 99 else: 100 return _INTERN_TABLE[self._type_enum + 100] 101 102 @property 103 def base_dtype(self): 104 """Returns a non-reference `DType` based on this `DType`.""" 105 if self._is_ref_dtype: 106 return _INTERN_TABLE[self._type_enum - 100] 107 else: 108 return self 109 110 @property 111 def real_dtype(self): 112 """Returns the dtype correspond to this dtype's real part.""" 113 base = self.base_dtype 114 if base == complex64: 115 return float32 116 elif base == complex128: 117 return float64 118 else: 119 return self 120 121 @property 122 def is_numpy_compatible(self): 123 numpy_incompatible = [ 124 types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE, 125 types_pb2.DT_RESOURCE_REF 126 ] 127 return self._type_enum not in numpy_incompatible 128 129 @property 130 def as_numpy_dtype(self): 131 """Returns a `numpy.dtype` based on this `DType`.""" 132 return _TF_TO_NP[self._type_enum] 133 134 @property 135 def as_datatype_enum(self): 136 """Returns a `types_pb2.DataType` enum value based on this `DType`.""" 137 return self._type_enum 138 139 @property 140 def is_bool(self): 141 """Returns whether this is a boolean data type""" 142 return self.base_dtype == bool 143 144 @property 145 def is_integer(self): 146 """Returns whether this is a (non-quantized) integer type.""" 147 return (self.is_numpy_compatible and not self.is_quantized and 148 np.issubdtype(self.as_numpy_dtype, np.integer)) 149 150 @property 151 def is_floating(self): 152 
"""Returns whether this is a (non-quantized, real) floating point type.""" 153 return ((self.is_numpy_compatible and 154 np.issubdtype(self.as_numpy_dtype, np.floating)) or 155 self.base_dtype == bfloat16) 156 157 @property 158 def is_complex(self): 159 """Returns whether this is a complex floating point type.""" 160 return self.base_dtype in (complex64, complex128) 161 162 @property 163 def is_quantized(self): 164 """Returns whether this is a quantized data type.""" 165 return self.base_dtype in [qint8, quint8, qint16, quint16, qint32] 166 167 @property 168 def is_unsigned(self): 169 """Returns whether this type is unsigned. 170 171 Non-numeric, unordered, and quantized types are not considered unsigned, and 172 this function returns `False`. 173 174 Returns: 175 Whether a `DType` is unsigned. 176 """ 177 try: 178 return self.min == 0 179 except TypeError: 180 return False 181 182 @property 183 def min(self): 184 """Returns the minimum representable value in this data type. 185 186 Raises: 187 TypeError: if this is a non-numeric, unordered, or quantized type. 188 189 """ 190 if (self.is_quantized or 191 self.base_dtype in (bool, string, complex64, complex128)): 192 raise TypeError("Cannot find minimum value of %s." % self) 193 194 # there is no simple way to get the min value of a dtype, we have to check 195 # float and int types separately 196 try: 197 return np.finfo(self.as_numpy_dtype()).min 198 except: # bare except as possible raises by finfo not documented 199 try: 200 return np.iinfo(self.as_numpy_dtype()).min 201 except: 202 if self.base_dtype == bfloat16: 203 return _np_bfloat16(float.fromhex("-0x1.FEp127")) 204 raise TypeError("Cannot find minimum value of %s." % self) 205 206 @property 207 def max(self): 208 """Returns the maximum representable value in this data type. 209 210 Raises: 211 TypeError: if this is a non-numeric, unordered, or quantized type. 
212 213 """ 214 if (self.is_quantized or 215 self.base_dtype in (bool, string, complex64, complex128)): 216 raise TypeError("Cannot find maximum value of %s." % self) 217 218 # there is no simple way to get the max value of a dtype, we have to check 219 # float and int types separately 220 try: 221 return np.finfo(self.as_numpy_dtype()).max 222 except: # bare except as possible raises by finfo not documented 223 try: 224 return np.iinfo(self.as_numpy_dtype()).max 225 except: 226 if self.base_dtype == bfloat16: 227 return _np_bfloat16(float.fromhex("0x1.FEp127")) 228 raise TypeError("Cannot find maximum value of %s." % self) 229 230 @property 231 def limits(self, clip_negative=True): 232 """Return intensity limits, i.e. (min, max) tuple, of the dtype. 233 Args: 234 clip_negative : bool, optional 235 If True, clip the negative range (i.e. return 0 for min intensity) 236 even if the image dtype allows negative values. 237 Returns 238 min, max : tuple 239 Lower and upper intensity limits. 240 """ 241 min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin 242 if clip_negative: 243 min = 0 # pylint: disable=redefined-builtin 244 return min, max 245 246 def is_compatible_with(self, other): 247 """Returns True if the `other` DType will be converted to this DType. 248 249 The conversion rules are as follows: 250 251 ```python 252 DType(T) .is_compatible_with(DType(T)) == True 253 DType(T) .is_compatible_with(DType(T).as_ref) == True 254 DType(T).as_ref.is_compatible_with(DType(T)) == False 255 DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True 256 ``` 257 258 Args: 259 other: A `DType` (or object that may be converted to a `DType`). 260 261 Returns: 262 True if a Tensor of the `other` `DType` will be implicitly converted to 263 this `DType`. 
264 """ 265 other = as_dtype(other) 266 return self._type_enum in (other.as_datatype_enum, 267 other.base_dtype.as_datatype_enum) 268 269 def __eq__(self, other): 270 """Returns True iff this DType refers to the same type as `other`.""" 271 if other is None: 272 return False 273 try: 274 dtype = as_dtype(other).as_datatype_enum 275 return self._type_enum == dtype # pylint: disable=protected-access 276 except TypeError: 277 return False 278 279 def __ne__(self, other): 280 """Returns True iff self != other.""" 281 return not self.__eq__(other) 282 283 @property 284 def name(self): 285 """Returns the string name for this `DType`.""" 286 return _TYPE_TO_STRING[self._type_enum] 287 288 def __int__(self): 289 return self._type_enum 290 291 def __str__(self): 292 return "<dtype: %r>" % self.name 293 294 def __repr__(self): 295 return "tf." + self.name 296 297 def __hash__(self): 298 return self._type_enum 299 300 @property 301 def size(self): 302 if (self._type_enum == types_pb2.DT_VARIANT or 303 self._type_enum == types_pb2.DT_RESOURCE): 304 return 1 305 return np.dtype(self.as_numpy_dtype).itemsize 306 307 308# Define data type range of numpy dtype 309dtype_range = { 310 np.bool_: (False, True), 311 np.bool8: (False, True), 312 np.uint8: (0, 255), 313 np.uint16: (0, 65535), 314 np.int8: (-128, 127), 315 np.int16: (-32768, 32767), 316 np.int64: (-2**63, 2**63 - 1), 317 np.uint64: (0, 2**64 - 1), 318 np.int32: (-2**31, 2**31 - 1), 319 np.uint32: (0, 2**32 - 1), 320 np.float32: (-1, 1), 321 np.float64: (-1, 1) 322} 323 324# Define standard wrappers for the types_pb2.DataType enum. 
# Each public dtype is a singleton DType instance. Where a
# `tf_export(name).export_constant(__name__, name)` call follows, the constant
# is registered for the public `tf.` API under that name.
resource = DType(types_pb2.DT_RESOURCE)
tf_export("resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
tf_export("variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
tf_export("float16").export_constant(__name__, "float16")
# `half` is an alias: the very same object as `float16`.
half = float16
tf_export("half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
tf_export("float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
tf_export("float64").export_constant(__name__, "float64")
# `double` is an alias: the very same object as `float64`.
double = float64
tf_export("double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
tf_export("int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
tf_export("uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("uint16").export_constant(__name__, "uint16")
# NOTE(review): uint32 and uint64 are defined but have no tf_export call here,
# although the DType class docstring lists `tf.uint32` / `tf.uint64`. Confirm
# whether these exports are intentionally omitted.
uint32 = DType(types_pb2.DT_UINT32)
uint64 = DType(types_pb2.DT_UINT64)
int16 = DType(types_pb2.DT_INT16)
tf_export("int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
tf_export("int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
tf_export("string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
tf_export("complex64").export_constant(__name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
tf_export("complex128").export_constant(__name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
tf_export("int64").export_constant(__name__, "int64")
# From here on, the module-level name `bool` is this DType, not the builtin.
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
tf_export("bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
tf_export("qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
tf_export("quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
tf_export("qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
tf_export("quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
tf_export("qint32").export_constant(__name__, "qint32")
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
tf_export("bfloat16").export_constant(__name__, "bfloat16")
# Reference-typed variants (enum = base enum + 100, see DType._as_ref). None
# of them is registered with tf_export in this module.
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)

# Maintain an intern table so that we don't have to create a large
# number of small objects.
# Keys come from each singleton's own enum value, so a key can never disagree
# with the DType it maps to. The tuple order fixes the dict's insertion order:
# base types first, then their `_ref` counterparts.
_INTERN_TABLE = {
    dt.as_datatype_enum: dt for dt in (
        float16, float32, float64, int32, uint8, uint16, uint32, uint64,
        int16, int8, string, complex64, complex128, int64, bool, qint8,
        quint8, qint16, quint16, qint32, bfloat16, resource, variant,
        float16_ref, float32_ref, float64_ref, int32_ref, uint32_ref,
        uint8_ref, uint16_ref, int16_ref, int8_ref, string_ref,
        complex64_ref, complex128_ref, int64_ref, uint64_ref, bool_ref,
        qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref,
        bfloat16_ref, resource_ref, variant_ref)
}

# Standard mappings between types_pb2.DataType values and string names.
# Canonical name of every DataType enum value; `DType.name` reads this table.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Reverse lookup (canonical name -> interned DType), used by `as_dtype`.
# Every enum in _TYPE_TO_STRING must also be a key of _INTERN_TABLE.
_STRING_TO_TF = {
    type_name: _INTERN_TABLE[type_enum]
    for type_enum, type_name in _TYPE_TO_STRING.items()
}
# Add non-canonical aliases.
_STRING_TO_TF["half"] = float16
_STRING_TO_TF["half_ref"] = float16_ref
_STRING_TO_TF["float"] = float32
_STRING_TO_TF["float_ref"] = float32_ref
_STRING_TO_TF["double"] = float64
_STRING_TO_TF["double_ref"] = float64_ref

# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
#
# Each is a structured dtype with a single scalar field named after the TF
# type. The field shape is deliberately omitted. The former
# `(name, type, 1)` spelling has been deprecated since NumPy 1.17
# (FutureWarning), because a shape of 1 is slated to mean a (1,)-shaped
# sub-array instead of a scalar. `(name, type)` is the equivalent scalar form.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is obtained from pywrap_tensorflow.TF_bfloat16_type() near the
# top of this module.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])

# Standard mappings between types_pb2.DataType values and numpy.dtypes.
# The module-level name `bool` is rebound to a DType above, so the builtin bool
# type is captured via `type(True)`. The mappings below previously spelled the
# builtins as `np.object` / `np.bool`. Those were plain aliases of `object` /
# `bool`, deprecated in NumPy 1.20 and removed in NumPy 1.24.
_py_bool = type(True)

_NP_TO_TF = frozenset([
    (np.float16, float16),
    (np.float32, float32),
    (np.float64, float64),
    (np.int32, int32),
    (np.int64, int64),
    (np.uint8, uint8),
    (np.uint16, uint16),
    (np.uint32, uint32),
    (np.uint64, uint64),
    (np.int16, int16),
    (np.int8, int8),
    (np.complex64, complex64),
    (np.complex128, complex128),
    (object, string),
    (_py_bool, bool),
    (_np_qint8, qint8),
    (_np_quint8, quint8),
    (_np_qint16, qint16),
    (_np_quint16, quint16),
    (_np_qint32, qint32),
    (_np_bfloat16, bfloat16),
])
_TF_TO_NP = {
    types_pb2.DT_HALF:
        np.float16,
    types_pb2.DT_FLOAT:
        np.float32,
    types_pb2.DT_DOUBLE:
        np.float64,
    types_pb2.DT_INT32:
        np.int32,
    types_pb2.DT_UINT8:
        np.uint8,
    types_pb2.DT_UINT16:
        np.uint16,
    types_pb2.DT_UINT32:
        np.uint32,
    types_pb2.DT_UINT64:
        np.uint64,
    types_pb2.DT_INT16:
        np.int16,
    types_pb2.DT_INT8:
        np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING:
        object,
    types_pb2.DT_COMPLEX64:
        np.complex64,
    types_pb2.DT_COMPLEX128:
        np.complex128,
    types_pb2.DT_INT64:
        np.int64,
    types_pb2.DT_BOOL:
        _py_bool,
    types_pb2.DT_QINT8:
        _np_qint8,
    types_pb2.DT_QUINT8:
        _np_quint8,
    types_pb2.DT_QINT16:
        _np_qint16,
    types_pb2.DT_QUINT16:
        _np_quint16,
    types_pb2.DT_QINT32:
        _np_qint32,
    types_pb2.DT_BFLOAT16:
        _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF:
        np.float16,
    types_pb2.DT_FLOAT_REF:
        np.float32,
    types_pb2.DT_DOUBLE_REF:
        np.float64,
    types_pb2.DT_INT32_REF:
        np.int32,
    types_pb2.DT_UINT32_REF:
        np.uint32,
    types_pb2.DT_UINT8_REF:
        np.uint8,
    types_pb2.DT_UINT16_REF:
        np.uint16,
    types_pb2.DT_INT16_REF:
        np.int16,
    types_pb2.DT_INT8_REF:
        np.int8,
    types_pb2.DT_STRING_REF:
        object,
    types_pb2.DT_COMPLEX64_REF:
        np.complex64,
    types_pb2.DT_COMPLEX128_REF:
        np.complex128,
    types_pb2.DT_INT64_REF:
        np.int64,
    types_pb2.DT_UINT64_REF:
        np.uint64,
    types_pb2.DT_BOOL_REF:
        _py_bool,
    types_pb2.DT_QINT8_REF:
        _np_qint8,
    types_pb2.DT_QUINT8_REF:
        _np_quint8,
    types_pb2.DT_QINT16_REF:
        _np_qint16,
    types_pb2.DT_QUINT16_REF:
        _np_quint16,
    types_pb2.DT_QINT32_REF:
        _np_qint32,
    types_pb2.DT_BFLOAT16_REF:
        _np_bfloat16,
}

QUANTIZED_DTYPES = frozenset([
    qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
    quint16_ref, qint32_ref
])
tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")


@tf_export("as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, or a `numpy.dtype`.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  if isinstance(type_value, DType):
    return type_value

  # DataType enum values and string names resolve by direct lookup.
  try:
    return _INTERN_TABLE[type_value]
  except KeyError:
    pass

  try:
    return _STRING_TO_TF[type_value]
  except KeyError:
    pass

  if isinstance(type_value, np.dtype):
    # The numpy dtype for strings is variable length, so it cannot be compared
    # with a single constant. Its kind ('S' = bytes, 'U' = unicode) identifies
    # it instead. This is equivalent to the former `np.string_` /
    # `np.unicode_` comparison, and those names were removed in NumPy 2.0.
    if type_value.kind in ("S", "U"):
      return string

  # Linear scan because numpy dtypes compare equal to several spellings
  # (scalar types, dtype objects), which a hash lookup would not honor.
  for key, val in _NP_TO_TF:
    try:
      if key == type_value:
        return val
    except TypeError as e:
      raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e))

  # Wrap in a 1-tuple so tuple inputs are formatted, not unpacked, by `%`.
  raise TypeError(
      "Cannot convert value %r to a TensorFlow DType." % (type_value,))