1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 #include "tensorflow/core/lib/strings/str_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21
22 namespace tensorflow {
23
24 struct DataTypeHasher {
operator ()tensorflow::DataTypeHasher25 std::size_t operator()(const DataType& k) const {
26 return std::hash<int>()(static_cast<int>(k));
27 }
28 };
29
30 // Mapping from some of the DType fields, for backward compatibility. All other
31 // dtypes are mapped to TFT_ANY, but can be added here if a counterpart is
32 // defined.
33 auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({
34 {DT_FLOAT, TFT_FLOAT},
35 {DT_DOUBLE, TFT_DOUBLE},
36 {DT_INT32, TFT_INT32},
37 {DT_UINT8, TFT_UINT8},
38 {DT_INT16, TFT_INT16},
39 {DT_INT8, TFT_INT8},
40 {DT_STRING, TFT_STRING},
41 {DT_COMPLEX64, TFT_COMPLEX64},
42 {DT_INT64, TFT_INT64},
43 {DT_BOOL, TFT_BOOL},
44 {DT_UINT16, TFT_UINT16},
45 {DT_COMPLEX128, TFT_COMPLEX128},
46 {DT_HALF, TFT_HALF},
47 {DT_UINT32, TFT_UINT32},
48 {DT_UINT64, TFT_UINT64},
49 });
50
map_dtype_to_tensor(const DataType & dtype,FullTypeDef * t)51 void map_dtype_to_tensor(const DataType& dtype, FullTypeDef* t) {
52 t->set_type_id(TFT_TENSOR);
53 // If the dtype is not mapped, assume it's not supported and use
54 // TFT_ANY, for compatibility. See DT_TO_FT for more details.
55 const auto& mapped = DT_TO_FT->find(dtype);
56 auto* arg = t->add_args();
57 if (mapped != DT_TO_FT->end()) {
58 arg->set_type_id(mapped->second);
59 } else {
60 arg->set_type_id(TFT_ANY);
61 }
62 }
63
operator <(const DeviceType & other) const64 bool DeviceType::operator<(const DeviceType& other) const {
65 return type_ < other.type_;
66 }
67
operator ==(const DeviceType & other) const68 bool DeviceType::operator==(const DeviceType& other) const {
69 return type_ == other.type_;
70 }
71
operator <<(std::ostream & os,const DeviceType & d)72 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
73 os << d.type();
74 return os;
75 }
76
77 const char* const DEVICE_DEFAULT = "DEFAULT";
78 const char* const DEVICE_CPU = "CPU";
79 const char* const DEVICE_GPU = "GPU";
80 const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
81
82 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
83 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
84 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
85 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
86 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
87
88 namespace {
DataTypeStringInternal(DataType dtype)89 string DataTypeStringInternal(DataType dtype) {
90 switch (dtype) {
91 case DT_INVALID:
92 return "INVALID";
93 case DT_FLOAT:
94 return "float";
95 case DT_DOUBLE:
96 return "double";
97 case DT_INT32:
98 return "int32";
99 case DT_UINT32:
100 return "uint32";
101 case DT_UINT8:
102 return "uint8";
103 case DT_UINT16:
104 return "uint16";
105 case DT_INT16:
106 return "int16";
107 case DT_INT8:
108 return "int8";
109 case DT_STRING:
110 return "string";
111 case DT_COMPLEX64:
112 return "complex64";
113 case DT_COMPLEX128:
114 return "complex128";
115 case DT_INT64:
116 return "int64";
117 case DT_UINT64:
118 return "uint64";
119 case DT_BOOL:
120 return "bool";
121 case DT_QINT8:
122 return "qint8";
123 case DT_QUINT8:
124 return "quint8";
125 case DT_QUINT16:
126 return "quint16";
127 case DT_QINT16:
128 return "qint16";
129 case DT_QINT32:
130 return "qint32";
131 case DT_BFLOAT16:
132 return "bfloat16";
133 case DT_HALF:
134 return "half";
135 case DT_RESOURCE:
136 return "resource";
137 case DT_VARIANT:
138 return "variant";
139 default:
140 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
141 return strings::StrCat("unknown dtype enum (", dtype, ")");
142 }
143 }
144 } // end namespace
145
DataTypeString(DataType dtype)146 string DataTypeString(DataType dtype) {
147 if (IsRefType(dtype)) {
148 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
149 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
150 }
151 return DataTypeStringInternal(dtype);
152 }
153
DataTypeFromString(StringPiece sp,DataType * dt)154 bool DataTypeFromString(StringPiece sp, DataType* dt) {
155 if (str_util::EndsWith(sp, "_ref")) {
156 sp.remove_suffix(4);
157 DataType non_ref;
158 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
159 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
160 return true;
161 } else {
162 return false;
163 }
164 }
165
166 if (sp == "float" || sp == "float32") {
167 *dt = DT_FLOAT;
168 return true;
169 } else if (sp == "double" || sp == "float64") {
170 *dt = DT_DOUBLE;
171 return true;
172 } else if (sp == "int32") {
173 *dt = DT_INT32;
174 return true;
175 } else if (sp == "uint32") {
176 *dt = DT_UINT32;
177 return true;
178 } else if (sp == "uint8") {
179 *dt = DT_UINT8;
180 return true;
181 } else if (sp == "uint16") {
182 *dt = DT_UINT16;
183 return true;
184 } else if (sp == "int16") {
185 *dt = DT_INT16;
186 return true;
187 } else if (sp == "int8") {
188 *dt = DT_INT8;
189 return true;
190 } else if (sp == "string") {
191 *dt = DT_STRING;
192 return true;
193 } else if (sp == "complex64") {
194 *dt = DT_COMPLEX64;
195 return true;
196 } else if (sp == "complex128") {
197 *dt = DT_COMPLEX128;
198 return true;
199 } else if (sp == "int64") {
200 *dt = DT_INT64;
201 return true;
202 } else if (sp == "uint64") {
203 *dt = DT_UINT64;
204 return true;
205 } else if (sp == "bool") {
206 *dt = DT_BOOL;
207 return true;
208 } else if (sp == "qint8") {
209 *dt = DT_QINT8;
210 return true;
211 } else if (sp == "quint8") {
212 *dt = DT_QUINT8;
213 return true;
214 } else if (sp == "qint16") {
215 *dt = DT_QINT16;
216 return true;
217 } else if (sp == "quint16") {
218 *dt = DT_QUINT16;
219 return true;
220 } else if (sp == "qint32") {
221 *dt = DT_QINT32;
222 return true;
223 } else if (sp == "bfloat16") {
224 *dt = DT_BFLOAT16;
225 return true;
226 } else if (sp == "half" || sp == "float16") {
227 *dt = DT_HALF;
228 return true;
229 } else if (sp == "resource") {
230 *dt = DT_RESOURCE;
231 return true;
232 } else if (sp == "variant") {
233 *dt = DT_VARIANT;
234 return true;
235 }
236 return false;
237 }
238
DeviceTypeString(const DeviceType & device_type)239 string DeviceTypeString(const DeviceType& device_type) {
240 return device_type.type();
241 }
242
DataTypeSliceString(const DataTypeSlice types)243 string DataTypeSliceString(const DataTypeSlice types) {
244 string out;
245 for (auto it = types.begin(); it != types.end(); ++it) {
246 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
247 DataTypeString(*it));
248 }
249 return out;
250 }
251
DataTypeAlwaysOnHost(DataType dt)252 bool DataTypeAlwaysOnHost(DataType dt) {
253 // Includes DT_STRING and DT_RESOURCE.
254 switch (dt) {
255 case DT_STRING:
256 case DT_STRING_REF:
257 case DT_RESOURCE:
258 return true;
259 default:
260 return false;
261 }
262 }
263
DataTypeSize(DataType dt)264 int DataTypeSize(DataType dt) {
265 #define CASE(T) \
266 case DataTypeToEnum<T>::value: \
267 return sizeof(T);
268 switch (dt) {
269 TF_CALL_POD_TYPES(CASE);
270 TF_CALL_QUANTIZED_TYPES(CASE);
271 // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
272 // they are not supported widely, but are explicitly listed here for
273 // bitcast.
274 TF_CALL_qint16(CASE);
275 TF_CALL_quint16(CASE);
276
277 default:
278 return 0;
279 }
280 #undef CASE
281 }
282
283 // Define DataTypeToEnum<T>::value.
284 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
285 constexpr DataType DataTypeToEnum<TYPE>::value;
286
287 DEFINE_DATATYPETOENUM_VALUE(float);
288 DEFINE_DATATYPETOENUM_VALUE(double);
289 DEFINE_DATATYPETOENUM_VALUE(int32);
290 DEFINE_DATATYPETOENUM_VALUE(uint32);
291 DEFINE_DATATYPETOENUM_VALUE(uint16);
292 DEFINE_DATATYPETOENUM_VALUE(uint8);
293 DEFINE_DATATYPETOENUM_VALUE(int16);
294 DEFINE_DATATYPETOENUM_VALUE(int8);
295 DEFINE_DATATYPETOENUM_VALUE(tstring);
296 DEFINE_DATATYPETOENUM_VALUE(complex64);
297 DEFINE_DATATYPETOENUM_VALUE(complex128);
298 DEFINE_DATATYPETOENUM_VALUE(int64);
299 DEFINE_DATATYPETOENUM_VALUE(uint64);
300 DEFINE_DATATYPETOENUM_VALUE(bool);
301 DEFINE_DATATYPETOENUM_VALUE(qint8);
302 DEFINE_DATATYPETOENUM_VALUE(quint8);
303 DEFINE_DATATYPETOENUM_VALUE(qint16);
304 DEFINE_DATATYPETOENUM_VALUE(quint16);
305 DEFINE_DATATYPETOENUM_VALUE(qint32);
306 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
307 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
308 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
309 DEFINE_DATATYPETOENUM_VALUE(Variant);
310 #undef DEFINE_DATATYPETOENUM_VALUE
311
312 } // namespace tensorflow
313