1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library of dtypes (Tensor element types).""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21from six.moves import builtins 22 23from tensorflow.core.framework import types_pb2 24# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid 25# protobuf errors where a file is defined twice on MacOS. 26# pylint: disable=invalid-import-order,g-bad-import-order 27from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import 28from tensorflow.python import _pywrap_bfloat16 29from tensorflow.python import _dtypes 30from tensorflow.python.util.tf_export import tf_export 31 32_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() 33 34 35# pylint: disable=slots-on-old-class 36@tf_export("dtypes.DType", "DType") 37class DType(_dtypes.DType): 38 """Represents the type of the elements in a `Tensor`. 39 40 The following `DType` objects are defined: 41 42 * `tf.float16`: 16-bit half-precision floating-point. 43 * `tf.float32`: 32-bit single-precision floating-point. 44 * `tf.float64`: 64-bit double-precision floating-point. 45 * `tf.bfloat16`: 16-bit truncated floating-point. 46 * `tf.complex64`: 64-bit single-precision complex. 
47 * `tf.complex128`: 128-bit double-precision complex. 48 * `tf.int8`: 8-bit signed integer. 49 * `tf.uint8`: 8-bit unsigned integer. 50 * `tf.uint16`: 16-bit unsigned integer. 51 * `tf.uint32`: 32-bit unsigned integer. 52 * `tf.uint64`: 64-bit unsigned integer. 53 * `tf.int16`: 16-bit signed integer. 54 * `tf.int32`: 32-bit signed integer. 55 * `tf.int64`: 64-bit signed integer. 56 * `tf.bool`: Boolean. 57 * `tf.string`: String. 58 * `tf.qint8`: Quantized 8-bit signed integer. 59 * `tf.quint8`: Quantized 8-bit unsigned integer. 60 * `tf.qint16`: Quantized 16-bit signed integer. 61 * `tf.quint16`: Quantized 16-bit unsigned integer. 62 * `tf.qint32`: Quantized 32-bit signed integer. 63 * `tf.resource`: Handle to a mutable resource. 64 * `tf.variant`: Values of arbitrary types. 65 66 The `tf.as_dtype()` function converts numpy types and string type 67 names to a `DType` object. 68 """ 69 __slots__ = () 70 71 @property 72 def _is_ref_dtype(self): 73 """Returns `True` if this `DType` represents a reference type.""" 74 return self._type_enum > 100 75 76 @property 77 def _as_ref(self): 78 """Returns a reference `DType` based on this `DType`.""" 79 if self._is_ref_dtype: 80 return self 81 else: 82 return _INTERN_TABLE[self._type_enum + 100] 83 84 @property 85 def base_dtype(self): 86 """Returns a non-reference `DType` based on this `DType`.""" 87 if self._is_ref_dtype: 88 return _INTERN_TABLE[self._type_enum - 100] 89 else: 90 return self 91 92 @property 93 def real_dtype(self): 94 """Returns the `DType` corresponding to this `DType`'s real part.""" 95 base = self.base_dtype 96 if base == complex64: 97 return float32 98 elif base == complex128: 99 return float64 100 else: 101 return self 102 103 @property 104 def as_numpy_dtype(self): 105 """Returns a Python `type` object based on this `DType`.""" 106 return _TF_TO_NP[self._type_enum] 107 108 @property 109 def min(self): 110 """Returns the minimum representable value in this data type. 
111 112 Raises: 113 TypeError: if this is a non-numeric, unordered, or quantized type. 114 115 """ 116 if (self.is_quantized or 117 self.base_dtype in (bool, string, complex64, complex128)): 118 raise TypeError("Cannot find minimum value of %s." % self) 119 120 # there is no simple way to get the min value of a dtype, we have to check 121 # float and int types separately 122 try: 123 return np.finfo(self.as_numpy_dtype).min 124 except: # bare except as possible raises by finfo not documented 125 try: 126 return np.iinfo(self.as_numpy_dtype).min 127 except: 128 if self.base_dtype == bfloat16: 129 return _np_bfloat16(float.fromhex("-0x1.FEp127")) 130 raise TypeError("Cannot find minimum value of %s." % self) 131 132 @property 133 def max(self): 134 """Returns the maximum representable value in this data type. 135 136 Raises: 137 TypeError: if this is a non-numeric, unordered, or quantized type. 138 139 """ 140 if (self.is_quantized or 141 self.base_dtype in (bool, string, complex64, complex128)): 142 raise TypeError("Cannot find maximum value of %s." % self) 143 144 # there is no simple way to get the max value of a dtype, we have to check 145 # float and int types separately 146 try: 147 return np.finfo(self.as_numpy_dtype).max 148 except: # bare except as possible raises by finfo not documented 149 try: 150 return np.iinfo(self.as_numpy_dtype).max 151 except: 152 if self.base_dtype == bfloat16: 153 return _np_bfloat16(float.fromhex("0x1.FEp127")) 154 raise TypeError("Cannot find maximum value of %s." % self) 155 156 @property 157 def limits(self, clip_negative=True): 158 """Return intensity limits, i.e. 159 160 (min, max) tuple, of the dtype. 161 Args: 162 clip_negative : bool, optional If True, clip the negative range (i.e. 163 return 0 for min intensity) even if the image dtype allows negative 164 values. Returns 165 min, max : tuple Lower and upper intensity limits. 
166 """ 167 min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin 168 if clip_negative: 169 min = 0 # pylint: disable=redefined-builtin 170 return min, max 171 172 def is_compatible_with(self, other): 173 """Returns True if the `other` DType will be converted to this DType. 174 175 The conversion rules are as follows: 176 177 ```python 178 DType(T) .is_compatible_with(DType(T)) == True 179 ``` 180 181 Args: 182 other: A `DType` (or object that may be converted to a `DType`). 183 184 Returns: 185 True if a Tensor of the `other` `DType` will be implicitly converted to 186 this `DType`. 187 """ 188 other = as_dtype(other) 189 return self._type_enum in (other.as_datatype_enum, 190 other.base_dtype.as_datatype_enum) 191 192 def __eq__(self, other): 193 """Returns True iff this DType refers to the same type as `other`.""" 194 if other is None: 195 return False 196 197 if type(other) != DType: # pylint: disable=unidiomatic-typecheck 198 try: 199 other = as_dtype(other) 200 except TypeError: 201 return False 202 203 return self._type_enum == other._type_enum # pylint: disable=protected-access 204 205 def __ne__(self, other): 206 """Returns True iff self != other.""" 207 return not self.__eq__(other) 208 209 # "If a class that overrides __eq__() needs to retain the implementation 210 # of __hash__() from a parent class, the interpreter must be told this 211 # explicitly by setting __hash__ = <ParentClass>.__hash__." 212 # TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++. 
213 __hash__ = _dtypes.DType.__hash__ 214 215 def __reduce__(self): 216 return as_dtype, (self.name,) 217# pylint: enable=slots-on-old-class 218 219 220# Define data type range of numpy dtype 221dtype_range = { 222 np.bool_: (False, True), 223 np.bool8: (False, True), 224 np.uint8: (0, 255), 225 np.uint16: (0, 65535), 226 np.int8: (-128, 127), 227 np.int16: (-32768, 32767), 228 np.int64: (-2**63, 2**63 - 1), 229 np.uint64: (0, 2**64 - 1), 230 np.int32: (-2**31, 2**31 - 1), 231 np.uint32: (0, 2**32 - 1), 232 np.float32: (-1, 1), 233 np.float64: (-1, 1) 234} 235 236# Define standard wrappers for the types_pb2.DataType enum. 237resource = DType(types_pb2.DT_RESOURCE) 238tf_export("dtypes.resource", "resource").export_constant(__name__, "resource") 239variant = DType(types_pb2.DT_VARIANT) 240tf_export("dtypes.variant", "variant").export_constant(__name__, "variant") 241float16 = DType(types_pb2.DT_HALF) 242tf_export("dtypes.float16", "float16").export_constant(__name__, "float16") 243half = float16 244tf_export("dtypes.half", "half").export_constant(__name__, "half") 245float32 = DType(types_pb2.DT_FLOAT) 246tf_export("dtypes.float32", "float32").export_constant(__name__, "float32") 247float64 = DType(types_pb2.DT_DOUBLE) 248tf_export("dtypes.float64", "float64").export_constant(__name__, "float64") 249double = float64 250tf_export("dtypes.double", "double").export_constant(__name__, "double") 251int32 = DType(types_pb2.DT_INT32) 252tf_export("dtypes.int32", "int32").export_constant(__name__, "int32") 253uint8 = DType(types_pb2.DT_UINT8) 254tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8") 255uint16 = DType(types_pb2.DT_UINT16) 256tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16") 257uint32 = DType(types_pb2.DT_UINT32) 258tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32") 259uint64 = DType(types_pb2.DT_UINT64) 260tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64") 261int16 = 
DType(types_pb2.DT_INT16) 262tf_export("dtypes.int16", "int16").export_constant(__name__, "int16") 263int8 = DType(types_pb2.DT_INT8) 264tf_export("dtypes.int8", "int8").export_constant(__name__, "int8") 265string = DType(types_pb2.DT_STRING) 266tf_export("dtypes.string", "string").export_constant(__name__, "string") 267complex64 = DType(types_pb2.DT_COMPLEX64) 268tf_export("dtypes.complex64", 269 "complex64").export_constant(__name__, "complex64") 270complex128 = DType(types_pb2.DT_COMPLEX128) 271tf_export("dtypes.complex128", 272 "complex128").export_constant(__name__, "complex128") 273int64 = DType(types_pb2.DT_INT64) 274tf_export("dtypes.int64", "int64").export_constant(__name__, "int64") 275bool = DType(types_pb2.DT_BOOL) # pylint: disable=redefined-builtin 276tf_export("dtypes.bool", "bool").export_constant(__name__, "bool") 277qint8 = DType(types_pb2.DT_QINT8) 278tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8") 279quint8 = DType(types_pb2.DT_QUINT8) 280tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8") 281qint16 = DType(types_pb2.DT_QINT16) 282tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16") 283quint16 = DType(types_pb2.DT_QUINT16) 284tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16") 285qint32 = DType(types_pb2.DT_QINT32) 286tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32") 287resource_ref = DType(types_pb2.DT_RESOURCE_REF) 288variant_ref = DType(types_pb2.DT_VARIANT_REF) 289bfloat16 = DType(types_pb2.DT_BFLOAT16) 290tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16") 291float16_ref = DType(types_pb2.DT_HALF_REF) 292half_ref = float16_ref 293float32_ref = DType(types_pb2.DT_FLOAT_REF) 294float64_ref = DType(types_pb2.DT_DOUBLE_REF) 295double_ref = float64_ref 296int32_ref = DType(types_pb2.DT_INT32_REF) 297uint32_ref = DType(types_pb2.DT_UINT32_REF) 298uint8_ref = DType(types_pb2.DT_UINT8_REF) 299uint16_ref = 
DType(types_pb2.DT_UINT16_REF) 300int16_ref = DType(types_pb2.DT_INT16_REF) 301int8_ref = DType(types_pb2.DT_INT8_REF) 302string_ref = DType(types_pb2.DT_STRING_REF) 303complex64_ref = DType(types_pb2.DT_COMPLEX64_REF) 304complex128_ref = DType(types_pb2.DT_COMPLEX128_REF) 305int64_ref = DType(types_pb2.DT_INT64_REF) 306uint64_ref = DType(types_pb2.DT_UINT64_REF) 307bool_ref = DType(types_pb2.DT_BOOL_REF) 308qint8_ref = DType(types_pb2.DT_QINT8_REF) 309quint8_ref = DType(types_pb2.DT_QUINT8_REF) 310qint16_ref = DType(types_pb2.DT_QINT16_REF) 311quint16_ref = DType(types_pb2.DT_QUINT16_REF) 312qint32_ref = DType(types_pb2.DT_QINT32_REF) 313bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) 314 315# Maintain an intern table so that we don't have to create a large 316# number of small objects. 317_INTERN_TABLE = { 318 types_pb2.DT_HALF: float16, 319 types_pb2.DT_FLOAT: float32, 320 types_pb2.DT_DOUBLE: float64, 321 types_pb2.DT_INT32: int32, 322 types_pb2.DT_UINT8: uint8, 323 types_pb2.DT_UINT16: uint16, 324 types_pb2.DT_UINT32: uint32, 325 types_pb2.DT_UINT64: uint64, 326 types_pb2.DT_INT16: int16, 327 types_pb2.DT_INT8: int8, 328 types_pb2.DT_STRING: string, 329 types_pb2.DT_COMPLEX64: complex64, 330 types_pb2.DT_COMPLEX128: complex128, 331 types_pb2.DT_INT64: int64, 332 types_pb2.DT_BOOL: bool, 333 types_pb2.DT_QINT8: qint8, 334 types_pb2.DT_QUINT8: quint8, 335 types_pb2.DT_QINT16: qint16, 336 types_pb2.DT_QUINT16: quint16, 337 types_pb2.DT_QINT32: qint32, 338 types_pb2.DT_BFLOAT16: bfloat16, 339 types_pb2.DT_RESOURCE: resource, 340 types_pb2.DT_VARIANT: variant, 341 types_pb2.DT_HALF_REF: float16_ref, 342 types_pb2.DT_FLOAT_REF: float32_ref, 343 types_pb2.DT_DOUBLE_REF: float64_ref, 344 types_pb2.DT_INT32_REF: int32_ref, 345 types_pb2.DT_UINT32_REF: uint32_ref, 346 types_pb2.DT_UINT8_REF: uint8_ref, 347 types_pb2.DT_UINT16_REF: uint16_ref, 348 types_pb2.DT_INT16_REF: int16_ref, 349 types_pb2.DT_INT8_REF: int8_ref, 350 types_pb2.DT_STRING_REF: string_ref, 351 
types_pb2.DT_COMPLEX64_REF: complex64_ref, 352 types_pb2.DT_COMPLEX128_REF: complex128_ref, 353 types_pb2.DT_INT64_REF: int64_ref, 354 types_pb2.DT_UINT64_REF: uint64_ref, 355 types_pb2.DT_BOOL_REF: bool_ref, 356 types_pb2.DT_QINT8_REF: qint8_ref, 357 types_pb2.DT_QUINT8_REF: quint8_ref, 358 types_pb2.DT_QINT16_REF: qint16_ref, 359 types_pb2.DT_QUINT16_REF: quint16_ref, 360 types_pb2.DT_QINT32_REF: qint32_ref, 361 types_pb2.DT_BFLOAT16_REF: bfloat16_ref, 362 types_pb2.DT_RESOURCE_REF: resource_ref, 363 types_pb2.DT_VARIANT_REF: variant_ref, 364} 365 366# Standard mappings between types_pb2.DataType values and string names. 367_TYPE_TO_STRING = { 368 types_pb2.DT_HALF: "float16", 369 types_pb2.DT_FLOAT: "float32", 370 types_pb2.DT_DOUBLE: "float64", 371 types_pb2.DT_INT32: "int32", 372 types_pb2.DT_UINT8: "uint8", 373 types_pb2.DT_UINT16: "uint16", 374 types_pb2.DT_UINT32: "uint32", 375 types_pb2.DT_UINT64: "uint64", 376 types_pb2.DT_INT16: "int16", 377 types_pb2.DT_INT8: "int8", 378 types_pb2.DT_STRING: "string", 379 types_pb2.DT_COMPLEX64: "complex64", 380 types_pb2.DT_COMPLEX128: "complex128", 381 types_pb2.DT_INT64: "int64", 382 types_pb2.DT_BOOL: "bool", 383 types_pb2.DT_QINT8: "qint8", 384 types_pb2.DT_QUINT8: "quint8", 385 types_pb2.DT_QINT16: "qint16", 386 types_pb2.DT_QUINT16: "quint16", 387 types_pb2.DT_QINT32: "qint32", 388 types_pb2.DT_BFLOAT16: "bfloat16", 389 types_pb2.DT_RESOURCE: "resource", 390 types_pb2.DT_VARIANT: "variant", 391 types_pb2.DT_HALF_REF: "float16_ref", 392 types_pb2.DT_FLOAT_REF: "float32_ref", 393 types_pb2.DT_DOUBLE_REF: "float64_ref", 394 types_pb2.DT_INT32_REF: "int32_ref", 395 types_pb2.DT_UINT32_REF: "uint32_ref", 396 types_pb2.DT_UINT8_REF: "uint8_ref", 397 types_pb2.DT_UINT16_REF: "uint16_ref", 398 types_pb2.DT_INT16_REF: "int16_ref", 399 types_pb2.DT_INT8_REF: "int8_ref", 400 types_pb2.DT_STRING_REF: "string_ref", 401 types_pb2.DT_COMPLEX64_REF: "complex64_ref", 402 types_pb2.DT_COMPLEX128_REF: "complex128_ref", 403 
types_pb2.DT_INT64_REF: "int64_ref", 404 types_pb2.DT_UINT64_REF: "uint64_ref", 405 types_pb2.DT_BOOL_REF: "bool_ref", 406 types_pb2.DT_QINT8_REF: "qint8_ref", 407 types_pb2.DT_QUINT8_REF: "quint8_ref", 408 types_pb2.DT_QINT16_REF: "qint16_ref", 409 types_pb2.DT_QUINT16_REF: "quint16_ref", 410 types_pb2.DT_QINT32_REF: "qint32_ref", 411 types_pb2.DT_BFLOAT16_REF: "bfloat16_ref", 412 types_pb2.DT_RESOURCE_REF: "resource_ref", 413 types_pb2.DT_VARIANT_REF: "variant_ref", 414} 415_STRING_TO_TF = { 416 value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items() 417} 418# Add non-canonical aliases. 419_STRING_TO_TF["half"] = float16 420_STRING_TO_TF["half_ref"] = float16_ref 421_STRING_TO_TF["float"] = float32 422_STRING_TO_TF["float_ref"] = float32_ref 423_STRING_TO_TF["double"] = float64 424_STRING_TO_TF["double_ref"] = float64_ref 425 426# Numpy representation for quantized dtypes. 427# 428# These are magic strings that are used in the swig wrapper to identify 429# quantized types. 430# TODO(mrry,keveman): Investigate Numpy type registration to replace this 431# hard-coding of names. 432_np_qint8 = np.dtype([("qint8", np.int8)]) 433_np_quint8 = np.dtype([("quint8", np.uint8)]) 434_np_qint16 = np.dtype([("qint16", np.int16)]) 435_np_quint16 = np.dtype([("quint16", np.uint16)]) 436_np_qint32 = np.dtype([("qint32", np.int32)]) 437 438# _np_bfloat16 is defined by a module import. 439 440# Custom struct dtype for directly-fed ResourceHandles of supported type(s). 441np_resource = np.dtype([("resource", np.ubyte)]) 442 443# Standard mappings between types_pb2.DataType values and numpy.dtypes. 
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    np.object_: string,
    np.string_: string,
    np.unicode_: string,
    np.bool_: bool,
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}

# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simple aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
for pdt in [
    np.intc,
    np.uintc,
    np.int_,
    np.uint,
    np.longlong,
    np.ulonglong,
]:
  if pdt not in _NP_TO_TF:
    # Reuse the TF dtype of the first already-mapped key whose numpy dtype
    # equals this platform type's dtype.
    _NP_TO_TF[pdt] = next(
        _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)


TF_VALUE_DTYPES = set(_NP_TO_TF.values())


# Mapping from types_pb2.DataType values to their numpy representation.
#
# `builtins.object` / `builtins.bool` are the very objects that the deprecated
# `np.object` / `np.bool` aliases referred to; those aliases were removed in
# NumPy 1.24 and made this module fail to import. `builtins` is required
# because this module rebinds the name `bool` to a `DType`.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING: builtins.object,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: builtins.bool,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: builtins.object,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: builtins.bool,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}

_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
_QUANTIZED_DTYPES_REF = frozenset(
    [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES",
        "QUANTIZED_DTYPES"]).export_constant(__name__, "QUANTIZED_DTYPES")

# Python builtin types accepted by `as_dtype`. `builtins.bool` is needed
# because the module-level name `bool` is the TF `DType`.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string
}

# Union of every lookup table; populated below and used by `as_dtype`.
_ANY_TO_TF = {}
_ANY_TO_TF.update(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)

# Ensure no collisions.
assert len(_ANY_TO_TF) == sum(
    len(d) for d in [_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF, _NP_TO_TF])


@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, a `numpy.dtype` or numpy scalar type, one of the
      Python types `float`, `bool` or `object`, or an object with a `dtype`
      attribute that `numpy.dtype` understands (e.g. a `numpy.ndarray`).

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  if isinstance(type_value, DType):
    return type_value

  if isinstance(type_value, np.dtype):
    try:
      return _NP_TO_TF[type_value.type]
    except KeyError:
      pass

  # Unhashable inputs (e.g. numpy arrays or lists) make the dict lookup raise
  # TypeError. Treat that as "not found" so the `dtype`-attribute fallback
  # below is still reachable and callers get the descriptive error otherwise.
  try:
    return _ANY_TO_TF[type_value]
  except (KeyError, TypeError):
    pass

  if hasattr(type_value, "dtype"):
    try:
      return _NP_TO_TF[np.dtype(type_value.dtype).type]
    except (KeyError, TypeError):
      pass

  raise TypeError("Cannot convert value %r to a TensorFlow DType." %
                  (type_value,))