• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library of dtypes (Tensor element types)."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.util.tf_export import tf_export
25
# NumPy-side scalar type for bfloat16, obtained from `pywrap_tensorflow`.
# It is used below as the NumPy representation of `tf.bfloat16` (see
# `_NP_TO_TF` / `_TF_TO_NP`) and is called like a type in `DType.min`/`max`.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
27
28
@tf_export("DType")
class DType(object):
  """Represents the type of the elements in a `Tensor`.

  The following `DType` objects are defined:

  * `tf.float16`: 16-bit half-precision floating-point.
  * `tf.float32`: 32-bit single-precision floating-point.
  * `tf.float64`: 64-bit double-precision floating-point.
  * `tf.bfloat16`: 16-bit truncated floating-point.
  * `tf.complex64`: 64-bit single-precision complex.
  * `tf.complex128`: 128-bit double-precision complex.
  * `tf.int8`: 8-bit signed integer.
  * `tf.uint8`: 8-bit unsigned integer.
  * `tf.uint16`: 16-bit unsigned integer.
  * `tf.uint32`: 32-bit unsigned integer.
  * `tf.uint64`: 64-bit unsigned integer.
  * `tf.int16`: 16-bit signed integer.
  * `tf.int32`: 32-bit signed integer.
  * `tf.int64`: 64-bit signed integer.
  * `tf.bool`: Boolean.
  * `tf.string`: String.
  * `tf.qint8`: Quantized 8-bit signed integer.
  * `tf.quint8`: Quantized 8-bit unsigned integer.
  * `tf.qint16`: Quantized 16-bit signed integer.
  * `tf.quint16`: Quantized 16-bit unsigned integer.
  * `tf.qint32`: Quantized 32-bit signed integer.
  * `tf.resource`: Handle to a mutable resource.
  * `tf.variant`: Values of arbitrary types.

  In addition, variants of these types with the `_ref` suffix are
  defined for reference-typed tensors.

  The `tf.as_dtype()` function converts numpy types and string type
  names to a `DType` object.
  """

  def __init__(self, type_enum):
    """Creates a new `DType`.

    NOTE(mrry): In normal circumstances, you should not need to
    construct a `DType` object directly. Instead, use the
    `tf.as_dtype()` function.

    Args:
      type_enum: A `types_pb2.DataType` enum value.

    Raises:
      TypeError: If `type_enum` is not a valid `types_pb2.DataType` value, or
        is `types_pb2.DT_INVALID`.

    """
    # TODO(mrry): Make the necessary changes (using __new__) to ensure
    # that calling this returns one of the interned values.
    type_enum = int(type_enum)
    if (type_enum not in types_pb2.DataType.values() or
        type_enum == types_pb2.DT_INVALID):
      raise TypeError(
          "type_enum is not a valid types_pb2.DataType: %s" % type_enum)
    self._type_enum = type_enum

  @property
  def _is_ref_dtype(self):
    """Returns `True` if this `DType` represents a reference type."""
    # Reference enum values are the base enum value offset by 100.
    return self._type_enum > 100

  @property
  def _as_ref(self):
    """Returns a reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return self
    else:
      return _INTERN_TABLE[self._type_enum + 100]

  @property
  def base_dtype(self):
    """Returns a non-reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return _INTERN_TABLE[self._type_enum - 100]
    else:
      return self

  @property
  def real_dtype(self):
    """Returns the dtype corresponding to this dtype's real part.

    For complex types this is the matching (non-reference) float type; every
    other type is returned unchanged.
    """
    base = self.base_dtype
    if base == complex64:
      return float32
    elif base == complex128:
      return float64
    else:
      return self

  @property
  def is_numpy_compatible(self):
    """Returns whether this type is not a variant or resource type.

    Variant and resource types (and their reference variants) have no entry
    in the TensorFlow-to-NumPy mapping.
    """
    numpy_incompatible = [
        types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
        types_pb2.DT_RESOURCE_REF
    ]
    return self._type_enum not in numpy_incompatible

  @property
  def as_numpy_dtype(self):
    """Returns a `numpy.dtype` based on this `DType`.

    Raises:
      KeyError: if this type has no entry in the TensorFlow-to-NumPy mapping
        (see `is_numpy_compatible`).
    """
    return _TF_TO_NP[self._type_enum]

  @property
  def as_datatype_enum(self):
    """Returns a `types_pb2.DataType` enum value based on this `DType`."""
    return self._type_enum

  @property
  def is_bool(self):
    """Returns whether this is a boolean data type."""
    return self.base_dtype == bool

  @property
  def is_integer(self):
    """Returns whether this is a (non-quantized) integer type."""
    return (self.is_numpy_compatible and not self.is_quantized and
            np.issubdtype(self.as_numpy_dtype, np.integer))

  @property
  def is_floating(self):
    """Returns whether this is a (non-quantized, real) floating point type."""
    # bfloat16 is checked explicitly in addition to the NumPy subtype test.
    return ((self.is_numpy_compatible and
             np.issubdtype(self.as_numpy_dtype, np.floating)) or
            self.base_dtype == bfloat16)

  @property
  def is_complex(self):
    """Returns whether this is a complex floating point type."""
    return self.base_dtype in (complex64, complex128)

  @property
  def is_quantized(self):
    """Returns whether this is a quantized data type."""
    return self.base_dtype in [qint8, quint8, qint16, quint16, qint32]

  @property
  def is_unsigned(self):
    """Returns whether this type is unsigned.

    Non-numeric, unordered, and quantized types are not considered unsigned, and
    this function returns `False`.

    Returns:
      Whether a `DType` is unsigned.
    """
    try:
      return self.min == 0
    except TypeError:
      return False

  @property
  def min(self):
    """Returns the minimum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find minimum value of %s." % self)

    # there is no simple way to get the min value of a dtype, we have to check
    # float and int types separately
    try:
      return np.finfo(self.as_numpy_dtype()).min
    except Exception:  # pylint: disable=broad-except
      # The exceptions finfo/iinfo may raise are not documented, so catch
      # broadly -- but not KeyboardInterrupt/SystemExit, which a bare
      # `except:` would also swallow.
      try:
        return np.iinfo(self.as_numpy_dtype()).min
      except Exception:  # pylint: disable=broad-except
        if self.base_dtype == bfloat16:
          return _np_bfloat16(float.fromhex("-0x1.FEp127"))
        raise TypeError("Cannot find minimum value of %s." % self)

  @property
  def max(self):
    """Returns the maximum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find maximum value of %s." % self)

    # there is no simple way to get the max value of a dtype, we have to check
    # float and int types separately
    try:
      return np.finfo(self.as_numpy_dtype()).max
    except Exception:  # pylint: disable=broad-except
      # The exceptions finfo/iinfo may raise are not documented, so catch
      # broadly -- but not KeyboardInterrupt/SystemExit, which a bare
      # `except:` would also swallow.
      try:
        return np.iinfo(self.as_numpy_dtype()).max
      except Exception:  # pylint: disable=broad-except
        if self.base_dtype == bfloat16:
          return _np_bfloat16(float.fromhex("0x1.FEp127"))
        raise TypeError("Cannot find maximum value of %s." % self)

  @property
  def limits(self, clip_negative=True):
    """Returns the intensity limits, i.e. a (min, max) tuple, of the dtype.

    NOTE: Because this is a property, the getter is always invoked with only
    `self`, so `clip_negative` always takes its default value of `True` and
    the returned minimum is always 0. The parameter is kept so the signature
    stays unchanged.

    Args:
      clip_negative: bool, optional. If True, clip the negative range (i.e.
        return 0 for min intensity) even if the image dtype allows negative
        values.

    Returns:
      A `(min, max)` tuple with the lower and upper intensity limits.

    Raises:
      KeyError: if this type's NumPy type has no entry in `dtype_range`.
    """
    lower, upper = dtype_range[self.as_numpy_dtype]
    if clip_negative:
      lower = 0
    return lower, upper

  def is_compatible_with(self, other):
    """Returns True if the `other` DType will be converted to this DType.

    The conversion rules are as follows:

    ```python
    DType(T)       .is_compatible_with(DType(T))        == True
    DType(T)       .is_compatible_with(DType(T).as_ref) == True
    DType(T).as_ref.is_compatible_with(DType(T))        == False
    DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
    ```

    Args:
      other: A `DType` (or object that may be converted to a `DType`).

    Returns:
      True if a Tensor of the `other` `DType` will be implicitly converted to
      this `DType`.
    """
    other = as_dtype(other)
    return self._type_enum in (other.as_datatype_enum,
                               other.base_dtype.as_datatype_enum)

  def __eq__(self, other):
    """Returns True iff this DType refers to the same type as `other`."""
    if other is None:
      return False
    try:
      # `other` may be anything `as_dtype` accepts (string name, numpy type,
      # enum value); values that cannot be converted compare unequal.
      dtype = as_dtype(other).as_datatype_enum
      return self._type_enum == dtype  # pylint: disable=protected-access
    except TypeError:
      return False

  def __ne__(self, other):
    """Returns True iff self != other."""
    return not self.__eq__(other)

  @property
  def name(self):
    """Returns the string name for this `DType`."""
    return _TYPE_TO_STRING[self._type_enum]

  def __int__(self):
    return self._type_enum

  def __str__(self):
    return "<dtype: %r>" % self.name

  def __repr__(self):
    return "tf." + self.name

  def __hash__(self):
    # Hash on the enum value so it agrees with `__eq__`.
    return self._type_enum

  @property
  def size(self):
    """Returns the NumPy item size of this type.

    `tf.variant` and `tf.resource` have no NumPy equivalent and report 1.
    """
    if (self._type_enum == types_pb2.DT_VARIANT or
        self._type_enum == types_pb2.DT_RESOURCE):
      return 1
    return np.dtype(self.as_numpy_dtype).itemsize
306
307
# Define data type range of numpy dtype.
# Maps a NumPy scalar type to the (min, max) tuple returned by `DType.limits`.
# Floating point types use (-1, 1) rather than their representable extremes.
# `np.bool8` is intentionally not listed: it was only an alias of `np.bool_`
# (i.e. the very same dict key) and newer NumPy releases no longer provide it.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-2**63, 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-2**31, 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1)
}
323
# Define standard wrappers for the types_pb2.DataType enum.
# Every public, non-reference dtype is registered with `tf_export` right after
# it is created; the `*_ref` variants below are not registered here.
resource = DType(types_pb2.DT_RESOURCE)
tf_export("resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
tf_export("variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
tf_export("float16").export_constant(__name__, "float16")
half = float16
tf_export("half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
tf_export("float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
tf_export("float64").export_constant(__name__, "float64")
double = float64
tf_export("double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
tf_export("int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
tf_export("uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("uint16").export_constant(__name__, "uint16")
# uint32/uint64 are listed as `tf.uint32`/`tf.uint64` in the `DType` docstring,
# so they are exported like every other public dtype.
uint32 = DType(types_pb2.DT_UINT32)
tf_export("uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
tf_export("uint64").export_constant(__name__, "uint64")
int16 = DType(types_pb2.DT_INT16)
tf_export("int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
tf_export("int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
tf_export("string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
tf_export("complex64").export_constant(__name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
tf_export("complex128").export_constant(__name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
tf_export("int64").export_constant(__name__, "int64")
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
tf_export("bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
tf_export("qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
tf_export("quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
tf_export("qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
tf_export("quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
tf_export("qint32").export_constant(__name__, "qint32")
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
tf_export("bfloat16").export_constant(__name__, "bfloat16")
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
398
# Intern table: one shared `DType` instance per `types_pb2.DataType` enum
# value, so lookups hand back an existing object instead of creating a large
# number of small ones. Each instance is keyed by the enum value it wraps.
_INTERN_TABLE = {
    interned.as_datatype_enum: interned
    for interned in (
        # Base types.
        float16,
        float32,
        float64,
        int32,
        uint8,
        uint16,
        uint32,
        uint64,
        int16,
        int8,
        string,
        complex64,
        complex128,
        int64,
        bool,
        qint8,
        quint8,
        qint16,
        quint16,
        qint32,
        bfloat16,
        resource,
        variant,
        # Reference types.
        float16_ref,
        float32_ref,
        float64_ref,
        int32_ref,
        uint32_ref,
        uint8_ref,
        uint16_ref,
        int16_ref,
        int8_ref,
        string_ref,
        complex64_ref,
        complex128_ref,
        int64_ref,
        uint64_ref,
        bool_ref,
        qint8_ref,
        quint8_ref,
        qint16_ref,
        quint16_ref,
        qint32_ref,
        bfloat16_ref,
        resource_ref,
        variant_ref,
    )
}
449
# Standard mappings between types_pb2.DataType values and string names.
# `_TYPE_TO_STRING` is what `DType.name` looks up; every reference type uses
# the base type's name with a "_ref" suffix.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Reverse lookup: canonical type name -> interned `DType`, derived from the
# name table so the two cannot drift apart.
_STRING_TO_TF = dict(
    (type_name, _INTERN_TABLE[enum_value])
    for enum_value, type_name in _TYPE_TO_STRING.items())
# Add non-canonical aliases.
_STRING_TO_TF.update({
    "half": float16,
    "half_ref": float16_ref,
    "float": float32,
    "float_ref": float32_ref,
    "double": float64,
    "double_ref": float64_ref,
})
510
# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
#
# Each one is a structured dtype with a single scalar field named after the
# quantized type. The fields are written as `(name, type)`: the older
# `(name, type, 1)` spelling denotes the same scalar field but is deprecated
# by NumPy.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is defined by a module import.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
527
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
#
# `_NP_TO_TF` is a set of (numpy type, DType) pairs rather than a dict:
# `as_dtype` scans it and compares each numpy type against its argument
# with `==`.
# NOTE(review): `np.object` and `np.bool` are NumPy's aliases of the builtin
# `object` and `bool`, and newer NumPy releases no longer provide them --
# confirm the supported NumPy range before upgrading.
_NP_TO_TF = frozenset([
    (np.float16, float16),
    (np.float32, float32),
    (np.float64, float64),
    (np.int32, int32),
    (np.int64, int64),
    (np.uint8, uint8),
    (np.uint16, uint16),
    (np.uint32, uint32),
    (np.uint64, uint64),
    (np.int16, int16),
    (np.int8, int8),
    (np.complex64, complex64),
    (np.complex128, complex128),
    (np.object, string),
    (np.bool, bool),
    (_np_qint8, qint8),
    (_np_quint8, quint8),
    (_np_qint16, qint16),
    (_np_quint16, quint16),
    (_np_qint32, qint32),
    (_np_bfloat16, bfloat16),
])
# Maps each `types_pb2.DataType` enum value to its numpy type; this is what
# `DType.as_numpy_dtype` looks up. A reference type maps to the same numpy
# type as its base type. There are no entries for the resource and variant
# types, so looking those up raises KeyError.
# NOTE(review): `np.object` and `np.bool` are NumPy's aliases of the builtin
# `object` and `bool`, and newer NumPy releases no longer provide them --
# confirm the supported NumPy range before upgrading.
_TF_TO_NP = {
    types_pb2.DT_HALF:
        np.float16,
    types_pb2.DT_FLOAT:
        np.float32,
    types_pb2.DT_DOUBLE:
        np.float64,
    types_pb2.DT_INT32:
        np.int32,
    types_pb2.DT_UINT8:
        np.uint8,
    types_pb2.DT_UINT16:
        np.uint16,
    types_pb2.DT_UINT32:
        np.uint32,
    types_pb2.DT_UINT64:
        np.uint64,
    types_pb2.DT_INT16:
        np.int16,
    types_pb2.DT_INT8:
        np.int8,
    # NOTE(touts): For strings we use np.object as it supports variable length
    # strings.
    types_pb2.DT_STRING:
        np.object,
    types_pb2.DT_COMPLEX64:
        np.complex64,
    types_pb2.DT_COMPLEX128:
        np.complex128,
    types_pb2.DT_INT64:
        np.int64,
    types_pb2.DT_BOOL:
        np.bool,
    types_pb2.DT_QINT8:
        _np_qint8,
    types_pb2.DT_QUINT8:
        _np_quint8,
    types_pb2.DT_QINT16:
        _np_qint16,
    types_pb2.DT_QUINT16:
        _np_quint16,
    types_pb2.DT_QINT32:
        _np_qint32,
    types_pb2.DT_BFLOAT16:
        _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF:
        np.float16,
    types_pb2.DT_FLOAT_REF:
        np.float32,
    types_pb2.DT_DOUBLE_REF:
        np.float64,
    types_pb2.DT_INT32_REF:
        np.int32,
    types_pb2.DT_UINT32_REF:
        np.uint32,
    types_pb2.DT_UINT8_REF:
        np.uint8,
    types_pb2.DT_UINT16_REF:
        np.uint16,
    types_pb2.DT_INT16_REF:
        np.int16,
    types_pb2.DT_INT8_REF:
        np.int8,
    types_pb2.DT_STRING_REF:
        np.object,
    types_pb2.DT_COMPLEX64_REF:
        np.complex64,
    types_pb2.DT_COMPLEX128_REF:
        np.complex128,
    types_pb2.DT_INT64_REF:
        np.int64,
    types_pb2.DT_UINT64_REF:
        np.uint64,
    types_pb2.DT_BOOL_REF:
        np.bool,
    types_pb2.DT_QINT8_REF:
        _np_qint8,
    types_pb2.DT_QUINT8_REF:
        _np_quint8,
    types_pb2.DT_QINT16_REF:
        _np_qint16,
    types_pb2.DT_QUINT16_REF:
        _np_quint16,
    types_pb2.DT_QINT32_REF:
        _np_qint32,
    types_pb2.DT_BFLOAT16_REF:
        _np_bfloat16,
}
642
# All quantized dtypes: the five base types followed by their reference
# variants.
QUANTIZED_DTYPES = frozenset((
    qint8,
    quint8,
    qint16,
    quint16,
    qint32,
    qint8_ref,
    quint8_ref,
    qint16_ref,
    quint16_ref,
    qint32_ref,
))
tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
648
649
@tf_export("as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  The lookup proceeds in order: an existing `DType` is returned unchanged;
  then the value is tried as a `DataType` enum key, then as a string type
  name; numpy string/unicode dtypes map to `tf.string`; finally the value is
  compared against each known numpy type.

  Args:
    type_value: A value that can be converted to a `tf.DType`
      object. This may currently be a `tf.DType` object, a
      [`DataType`
        enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, or a `numpy.dtype`.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  if isinstance(type_value, DType):
    return type_value

  # Enum values first, then canonical names and aliases.
  for lookup_table in (_INTERN_TABLE, _STRING_TO_TF):
    try:
      return lookup_table[type_value]
    except KeyError:
      continue

  # The numpy dtype for strings is variable length. We can not compare
  # dtype with a single constant (np.string does not exist) to decide
  # dtype is a "string" type. We need to compare the dtype.type to be
  # sure it's a string type.
  if isinstance(type_value, np.dtype):
    scalar_type = type_value.type
    if scalar_type == np.string_ or scalar_type == np.unicode_:
      return string

  for np_type, tf_type in _NP_TO_TF:
    try:
      matches = np_type == type_value
    except TypeError as e:
      # The comparison itself can raise TypeError; report which value could
      # not be converted.
      raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e))
    if matches:
      return tf_type

  raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
696