1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Dtypes and dtype utilities.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.ops.numpy_ops import np_export 25 26 27# We use numpy's dtypes instead of TF's, because the user expects to use them 28# with numpy facilities such as `np.dtype(np.int64)` and 29# `if x.dtype.type is np.int64`. 30bool_ = np_export.np_export_constant(__name__, 'bool_', np.bool_) 31complex_ = np_export.np_export_constant(__name__, 'complex_', np.complex_) 32complex128 = np_export.np_export_constant(__name__, 'complex128', np.complex128) 33complex64 = np_export.np_export_constant(__name__, 'complex64', np.complex64) 34float_ = np_export.np_export_constant(__name__, 'float_', np.float_) 35float16 = np_export.np_export_constant(__name__, 'float16', np.float16) 36float32 = np_export.np_export_constant(__name__, 'float32', np.float32) 37float64 = np_export.np_export_constant(__name__, 'float64', np.float64) 38inexact = np_export.np_export_constant(__name__, 'inexact', np.inexact) 39int_ = np_export.np_export_constant(__name__, 'int_', np.int_) 40int16 = np_export.np_export_constant(__name__, 'int16', np.int16) 41int32 = np_export.np_export_constant(__name__, 'int32', np.int32) 42int64 = np_export.np_export_constant(__name__, 'int64', np.int64) 43int8 = np_export.np_export_constant(__name__, 'int8', np.int8) 44object_ = np_export.np_export_constant(__name__, 'object_', np.object_) 45string_ = np_export.np_export_constant(__name__, 'string_', np.string_) 46uint16 = np_export.np_export_constant(__name__, 'uint16', np.uint16) 47uint32 = np_export.np_export_constant(__name__, 'uint32', np.uint32) 48uint64 = np_export.np_export_constant(__name__, 'uint64', np.uint64) 49uint8 = np_export.np_export_constant(__name__, 'uint8', np.uint8) 50unicode_ = np_export.np_export_constant(__name__, 'unicode_', np.unicode_) 51 52 53iinfo = np_export.np_export_constant(__name__, 'iinfo', np.iinfo) 54 55 56issubdtype = np_export.np_export('issubdtype')(np.issubdtype) 57 58 59_to_float32 = { 60 np.dtype('float64'): np.dtype('float32'), 61 np.dtype('complex128'): np.dtype('complex64'), 62} 63 64 65_cached_np_dtypes = {} 66 67 68# Difference between is_prefer_float32 and is_allow_float64: is_prefer_float32 69# only decides which dtype to use for Python floats; is_allow_float64 decides 70# whether float64 dtypes can ever appear in programs. The latter is more 71# restrictive than the former. 72_prefer_float32 = False 73 74 75# TODO(b/178862061): Consider removing this knob 76_allow_float64 = True 77 78 79def is_prefer_float32(): 80 return _prefer_float32 81 82 83def set_prefer_float32(b): 84 global _prefer_float32 85 _prefer_float32 = b 86 87 88def is_allow_float64(): 89 return _allow_float64 90 91 92def set_allow_float64(b): 93 global _allow_float64 94 _allow_float64 = b 95 96 97def canonicalize_dtype(dtype): 98 if not _allow_float64: 99 try: 100 return _to_float32[dtype] 101 except KeyError: 102 pass 103 return dtype 104 105 106def _result_type(*arrays_and_dtypes): 107 def preprocess_float(x): 108 if is_prefer_float32() and isinstance(x, float): 109 return np.float32(x) 110 return x 111 arrays_and_dtypes = [preprocess_float(x) for x in arrays_and_dtypes] 112 dtype = np.result_type(*arrays_and_dtypes) 113 return dtypes.as_dtype(canonicalize_dtype(dtype)) 114 115 116def _get_cached_dtype(dtype): 117 """Returns an np.dtype for the TensorFlow DType.""" 118 global _cached_np_dtypes 119 try: 120 return _cached_np_dtypes[dtype] 121 except KeyError: 122 pass 123 cached_dtype = np.dtype(dtype.as_numpy_dtype) 124 _cached_np_dtypes[dtype] = cached_dtype 125 return cached_dtype 126 127 128def default_float_type(): 129 """Gets the default float type. 130 131 Returns: 132 If `is_prefer_float32()` is false and `is_allow_float64()` is true, returns 133 float64; otherwise returns float32. 134 """ 135 if not is_prefer_float32() and is_allow_float64(): 136 return float64 137 else: 138 return float32 139