1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 #include "tensorflow/core/lib/strings/str_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21
22 namespace tensorflow {
23
24 struct DataTypeHasher {
operator ()tensorflow::DataTypeHasher25 std::size_t operator()(const DataType& k) const {
26 return std::hash<int>()(static_cast<int>(k));
27 }
28 };
29
// Mapping from some of the DType fields to their FullTypeId counterparts, for
// backward compatibility. DataTypes without an entry here are left unset (not
// mapped to TFT_ANY) by map_dtype_to_tensor; more entries can be added once a
// counterpart is defined.
33 auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({
34 {DT_FLOAT, TFT_FLOAT},
35 {DT_DOUBLE, TFT_DOUBLE},
36 {DT_INT32, TFT_INT32},
37 {DT_UINT8, TFT_UINT8},
38 {DT_INT16, TFT_INT16},
39 {DT_INT8, TFT_INT8},
40 {DT_STRING, TFT_STRING},
41 {DT_COMPLEX64, TFT_COMPLEX64},
42 {DT_INT64, TFT_INT64},
43 {DT_BOOL, TFT_BOOL},
44 {DT_UINT16, TFT_UINT16},
45 {DT_COMPLEX128, TFT_COMPLEX128},
46 {DT_HALF, TFT_HALF},
47 {DT_UINT32, TFT_UINT32},
48 {DT_UINT64, TFT_UINT64},
49 {DT_VARIANT, TFT_LEGACY_VARIANT},
50 });
51
map_dtype_to_tensor(const DataType & dtype,FullTypeDef & t)52 void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t) {
53 t.Clear();
54
55 const auto& mapped = DT_TO_FT->find(dtype);
56 // Only map known types, everything else remains unset. This is so that we
57 // only set the most specific type when it is fully known. For example, if the
58 // dtype is DT_VARIANT, then we don't know much and opt to assume that
59 // the type is unset, rather than TFT_ANY.
60 if (mapped != DT_TO_FT->end()) {
61 t.set_type_id(mapped->second);
62 }
63 }
64
operator <(const DeviceType & other) const65 bool DeviceType::operator<(const DeviceType& other) const {
66 return type_ < other.type_;
67 }
68
operator ==(const DeviceType & other) const69 bool DeviceType::operator==(const DeviceType& other) const {
70 return type_ == other.type_;
71 }
72
operator <<(std::ostream & os,const DeviceType & d)73 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
74 os << d.type();
75 return os;
76 }
77
// Canonical device-type name strings.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_TPU{"TPU"};
const char* const DEVICE_TPU_SYSTEM{"TPU_SYSTEM"};
83
84 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
85 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
86 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
87 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
88 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
89
90 namespace {
DataTypeStringInternal(DataType dtype)91 string DataTypeStringInternal(DataType dtype) {
92 switch (dtype) {
93 case DT_INVALID:
94 return "INVALID";
95 case DT_FLOAT:
96 return "float";
97 case DT_DOUBLE:
98 return "double";
99 case DT_INT32:
100 return "int32";
101 case DT_UINT32:
102 return "uint32";
103 case DT_UINT8:
104 return "uint8";
105 case DT_UINT16:
106 return "uint16";
107 case DT_INT16:
108 return "int16";
109 case DT_INT8:
110 return "int8";
111 case DT_STRING:
112 return "string";
113 case DT_COMPLEX64:
114 return "complex64";
115 case DT_COMPLEX128:
116 return "complex128";
117 case DT_INT64:
118 return "int64";
119 case DT_UINT64:
120 return "uint64";
121 case DT_BOOL:
122 return "bool";
123 case DT_QINT8:
124 return "qint8";
125 case DT_QUINT8:
126 return "quint8";
127 case DT_QUINT16:
128 return "quint16";
129 case DT_QINT16:
130 return "qint16";
131 case DT_QINT32:
132 return "qint32";
133 case DT_BFLOAT16:
134 return "bfloat16";
135 case DT_HALF:
136 return "half";
137 case DT_RESOURCE:
138 return "resource";
139 case DT_VARIANT:
140 return "variant";
141 default:
142 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
143 return strings::StrCat("unknown dtype enum (", dtype, ")");
144 }
145 }
146 } // end namespace
147
DataTypeString(DataType dtype)148 string DataTypeString(DataType dtype) {
149 if (IsRefType(dtype)) {
150 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
151 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
152 }
153 return DataTypeStringInternal(dtype);
154 }
155
DataTypeFromString(StringPiece sp,DataType * dt)156 bool DataTypeFromString(StringPiece sp, DataType* dt) {
157 if (str_util::EndsWith(sp, "_ref")) {
158 sp.remove_suffix(4);
159 DataType non_ref;
160 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
161 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
162 return true;
163 } else {
164 return false;
165 }
166 }
167
168 if (sp == "float" || sp == "float32") {
169 *dt = DT_FLOAT;
170 return true;
171 } else if (sp == "double" || sp == "float64") {
172 *dt = DT_DOUBLE;
173 return true;
174 } else if (sp == "int32") {
175 *dt = DT_INT32;
176 return true;
177 } else if (sp == "uint32") {
178 *dt = DT_UINT32;
179 return true;
180 } else if (sp == "uint8") {
181 *dt = DT_UINT8;
182 return true;
183 } else if (sp == "uint16") {
184 *dt = DT_UINT16;
185 return true;
186 } else if (sp == "int16") {
187 *dt = DT_INT16;
188 return true;
189 } else if (sp == "int8") {
190 *dt = DT_INT8;
191 return true;
192 } else if (sp == "string") {
193 *dt = DT_STRING;
194 return true;
195 } else if (sp == "complex64") {
196 *dt = DT_COMPLEX64;
197 return true;
198 } else if (sp == "complex128") {
199 *dt = DT_COMPLEX128;
200 return true;
201 } else if (sp == "int64") {
202 *dt = DT_INT64;
203 return true;
204 } else if (sp == "uint64") {
205 *dt = DT_UINT64;
206 return true;
207 } else if (sp == "bool") {
208 *dt = DT_BOOL;
209 return true;
210 } else if (sp == "qint8") {
211 *dt = DT_QINT8;
212 return true;
213 } else if (sp == "quint8") {
214 *dt = DT_QUINT8;
215 return true;
216 } else if (sp == "qint16") {
217 *dt = DT_QINT16;
218 return true;
219 } else if (sp == "quint16") {
220 *dt = DT_QUINT16;
221 return true;
222 } else if (sp == "qint32") {
223 *dt = DT_QINT32;
224 return true;
225 } else if (sp == "bfloat16") {
226 *dt = DT_BFLOAT16;
227 return true;
228 } else if (sp == "half" || sp == "float16") {
229 *dt = DT_HALF;
230 return true;
231 } else if (sp == "resource") {
232 *dt = DT_RESOURCE;
233 return true;
234 } else if (sp == "variant") {
235 *dt = DT_VARIANT;
236 return true;
237 }
238 return false;
239 }
240
DeviceTypeString(const DeviceType & device_type)241 string DeviceTypeString(const DeviceType& device_type) {
242 return device_type.type();
243 }
244
DataTypeSliceString(const DataTypeSlice types)245 string DataTypeSliceString(const DataTypeSlice types) {
246 string out;
247 for (auto it = types.begin(); it != types.end(); ++it) {
248 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
249 DataTypeString(*it));
250 }
251 return out;
252 }
253
DataTypeAlwaysOnHost(DataType dt)254 bool DataTypeAlwaysOnHost(DataType dt) {
255 // Includes DT_STRING and DT_RESOURCE.
256 switch (dt) {
257 case DT_STRING:
258 case DT_STRING_REF:
259 case DT_RESOURCE:
260 return true;
261 default:
262 return false;
263 }
264 }
265
DataTypeSize(DataType dt)266 int DataTypeSize(DataType dt) {
267 #define CASE(T) \
268 case DataTypeToEnum<T>::value: \
269 return sizeof(T);
270 switch (dt) {
271 TF_CALL_POD_TYPES(CASE);
272 TF_CALL_QUANTIZED_TYPES(CASE);
273 // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
274 // they are not supported widely, but are explicitly listed here for
275 // bitcast.
276 TF_CALL_qint16(CASE);
277 TF_CALL_quint16(CASE);
278
279 default:
280 return 0;
281 }
282 #undef CASE
283 }
284
285 // Define DataTypeToEnum<T>::value.
286 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
287 constexpr DataType DataTypeToEnum<TYPE>::value;
288
289 DEFINE_DATATYPETOENUM_VALUE(float);
290 DEFINE_DATATYPETOENUM_VALUE(double);
291 DEFINE_DATATYPETOENUM_VALUE(int32);
292 DEFINE_DATATYPETOENUM_VALUE(uint32);
293 DEFINE_DATATYPETOENUM_VALUE(uint16);
294 DEFINE_DATATYPETOENUM_VALUE(uint8);
295 DEFINE_DATATYPETOENUM_VALUE(int16);
296 DEFINE_DATATYPETOENUM_VALUE(int8);
297 DEFINE_DATATYPETOENUM_VALUE(tstring);
298 DEFINE_DATATYPETOENUM_VALUE(complex64);
299 DEFINE_DATATYPETOENUM_VALUE(complex128);
300 DEFINE_DATATYPETOENUM_VALUE(int64_t);
301 DEFINE_DATATYPETOENUM_VALUE(uint64);
302 DEFINE_DATATYPETOENUM_VALUE(bool);
303 DEFINE_DATATYPETOENUM_VALUE(qint8);
304 DEFINE_DATATYPETOENUM_VALUE(quint8);
305 DEFINE_DATATYPETOENUM_VALUE(qint16);
306 DEFINE_DATATYPETOENUM_VALUE(quint16);
307 DEFINE_DATATYPETOENUM_VALUE(qint32);
308 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
309 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
310 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
311 DEFINE_DATATYPETOENUM_VALUE(Variant);
312 #undef DEFINE_DATATYPETOENUM_VALUE
313
314 } // namespace tensorflow
315