1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24
operator <(const DeviceType & other) const25 bool DeviceType::operator<(const DeviceType& other) const {
26 return type_ < other.type_;
27 }
28
operator ==(const DeviceType & other) const29 bool DeviceType::operator==(const DeviceType& other) const {
30 return type_ == other.type_;
31 }
32
operator <<(std::ostream & os,const DeviceType & d)33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
34 os << d.type();
35 return os;
36 }
37
// Canonical device-type name strings. Each is a constant pointer to a string
// literal, so it needs no dynamic initialization. DEVICE_CPU, DEVICE_GPU and
// DEVICE_SYCL are also used to initialize the DeviceName<...>::value strings
// defined in this file.
const char* const DEVICE_DEFAULT = "DEFAULT";
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";
42
// Definitions of the static DeviceName<Device>::value members. Each one maps an
// Eigen device type to its device-name constant.
// The GPU definition is compiled only when CUDA or ROCm support is enabled.
// The SYCL definition is compiled only when TENSORFLOW_USE_SYCL is defined.
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
#endif  // TENSORFLOW_USE_SYCL
51
52 namespace {
DataTypeStringInternal(DataType dtype)53 string DataTypeStringInternal(DataType dtype) {
54 switch (dtype) {
55 case DT_INVALID:
56 return "INVALID";
57 case DT_FLOAT:
58 return "float";
59 case DT_DOUBLE:
60 return "double";
61 case DT_INT32:
62 return "int32";
63 case DT_UINT32:
64 return "uint32";
65 case DT_UINT8:
66 return "uint8";
67 case DT_UINT16:
68 return "uint16";
69 case DT_INT16:
70 return "int16";
71 case DT_INT8:
72 return "int8";
73 case DT_STRING:
74 return "string";
75 case DT_COMPLEX64:
76 return "complex64";
77 case DT_COMPLEX128:
78 return "complex128";
79 case DT_INT64:
80 return "int64";
81 case DT_UINT64:
82 return "uint64";
83 case DT_BOOL:
84 return "bool";
85 case DT_QINT8:
86 return "qint8";
87 case DT_QUINT8:
88 return "quint8";
89 case DT_QUINT16:
90 return "quint16";
91 case DT_QINT16:
92 return "qint16";
93 case DT_QINT32:
94 return "qint32";
95 case DT_BFLOAT16:
96 return "bfloat16";
97 case DT_HALF:
98 return "half";
99 case DT_RESOURCE:
100 return "resource";
101 case DT_VARIANT:
102 return "variant";
103 default:
104 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
105 return strings::StrCat("unknown dtype enum (", dtype, ")");
106 }
107 }
108 } // end namespace
109
DataTypeString(DataType dtype)110 string DataTypeString(DataType dtype) {
111 if (IsRefType(dtype)) {
112 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
113 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
114 }
115 return DataTypeStringInternal(dtype);
116 }
117
DataTypeFromString(StringPiece sp,DataType * dt)118 bool DataTypeFromString(StringPiece sp, DataType* dt) {
119 if (str_util::EndsWith(sp, "_ref")) {
120 sp.remove_suffix(4);
121 DataType non_ref;
122 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
123 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
124 return true;
125 } else {
126 return false;
127 }
128 }
129
130 if (sp == "float" || sp == "float32") {
131 *dt = DT_FLOAT;
132 return true;
133 } else if (sp == "double" || sp == "float64") {
134 *dt = DT_DOUBLE;
135 return true;
136 } else if (sp == "int32") {
137 *dt = DT_INT32;
138 return true;
139 } else if (sp == "uint32") {
140 *dt = DT_UINT32;
141 return true;
142 } else if (sp == "uint8") {
143 *dt = DT_UINT8;
144 return true;
145 } else if (sp == "uint16") {
146 *dt = DT_UINT16;
147 return true;
148 } else if (sp == "int16") {
149 *dt = DT_INT16;
150 return true;
151 } else if (sp == "int8") {
152 *dt = DT_INT8;
153 return true;
154 } else if (sp == "string") {
155 *dt = DT_STRING;
156 return true;
157 } else if (sp == "complex64") {
158 *dt = DT_COMPLEX64;
159 return true;
160 } else if (sp == "complex128") {
161 *dt = DT_COMPLEX128;
162 return true;
163 } else if (sp == "int64") {
164 *dt = DT_INT64;
165 return true;
166 } else if (sp == "uint64") {
167 *dt = DT_UINT64;
168 return true;
169 } else if (sp == "bool") {
170 *dt = DT_BOOL;
171 return true;
172 } else if (sp == "qint8") {
173 *dt = DT_QINT8;
174 return true;
175 } else if (sp == "quint8") {
176 *dt = DT_QUINT8;
177 return true;
178 } else if (sp == "qint16") {
179 *dt = DT_QINT16;
180 return true;
181 } else if (sp == "quint16") {
182 *dt = DT_QUINT16;
183 return true;
184 } else if (sp == "qint32") {
185 *dt = DT_QINT32;
186 return true;
187 } else if (sp == "bfloat16") {
188 *dt = DT_BFLOAT16;
189 return true;
190 } else if (sp == "half" || sp == "float16") {
191 *dt = DT_HALF;
192 return true;
193 } else if (sp == "resource") {
194 *dt = DT_RESOURCE;
195 return true;
196 } else if (sp == "variant") {
197 *dt = DT_VARIANT;
198 return true;
199 }
200 return false;
201 }
202
DeviceTypeString(const DeviceType & device_type)203 string DeviceTypeString(const DeviceType& device_type) {
204 return device_type.type();
205 }
206
DataTypeSliceString(const DataTypeSlice types)207 string DataTypeSliceString(const DataTypeSlice types) {
208 string out;
209 for (auto it = types.begin(); it != types.end(); ++it) {
210 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
211 DataTypeString(*it));
212 }
213 return out;
214 }
215
DataTypeAlwaysOnHost(DataType dt)216 bool DataTypeAlwaysOnHost(DataType dt) {
217 // Includes DT_STRING and DT_RESOURCE.
218 switch (dt) {
219 case DT_STRING:
220 case DT_STRING_REF:
221 case DT_RESOURCE:
222 return true;
223 default:
224 return false;
225 }
226 }
227
DataTypeSize(DataType dt)228 int DataTypeSize(DataType dt) {
229 #define CASE(T) \
230 case DataTypeToEnum<T>::value: \
231 return sizeof(T);
232 switch (dt) {
233 TF_CALL_POD_TYPES(CASE);
234 TF_CALL_QUANTIZED_TYPES(CASE);
235 // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
236 // they are not supported widely, but are explicitly listed here for
237 // bitcast.
238 TF_CALL_qint16(CASE);
239 TF_CALL_quint16(CASE);
240
241 // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
242 // don't want to define kernels for them at this stage to avoid binary
243 // bloat.
244 TF_CALL_uint32(CASE);
245 TF_CALL_uint64(CASE);
246 default:
247 return 0;
248 }
249 #undef CASE
250 }
251
// Out-of-class (namespace-scope) definitions of the constexpr static data
// member DataTypeToEnum<T>::value, one per supported C++ element type.
// Under pre-C++17 rules, a static constexpr member that is ODR-used (e.g.
// bound to a const reference) needs exactly one such definition. These
// provide it.
#define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
  constexpr DataType DataTypeToEnum<TYPE>::value;

DEFINE_DATATYPETOENUM_VALUE(float);
DEFINE_DATATYPETOENUM_VALUE(double);
DEFINE_DATATYPETOENUM_VALUE(int32);
DEFINE_DATATYPETOENUM_VALUE(uint32);
DEFINE_DATATYPETOENUM_VALUE(uint16);
DEFINE_DATATYPETOENUM_VALUE(uint8);
DEFINE_DATATYPETOENUM_VALUE(int16);
DEFINE_DATATYPETOENUM_VALUE(int8);
DEFINE_DATATYPETOENUM_VALUE(tstring);
DEFINE_DATATYPETOENUM_VALUE(complex64);
DEFINE_DATATYPETOENUM_VALUE(complex128);
DEFINE_DATATYPETOENUM_VALUE(int64);
DEFINE_DATATYPETOENUM_VALUE(uint64);
DEFINE_DATATYPETOENUM_VALUE(bool);
DEFINE_DATATYPETOENUM_VALUE(qint8);
DEFINE_DATATYPETOENUM_VALUE(quint8);
DEFINE_DATATYPETOENUM_VALUE(qint16);
DEFINE_DATATYPETOENUM_VALUE(quint16);
DEFINE_DATATYPETOENUM_VALUE(qint32);
DEFINE_DATATYPETOENUM_VALUE(bfloat16);
DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
DEFINE_DATATYPETOENUM_VALUE(Variant);
#undef DEFINE_DATATYPETOENUM_VALUE
280
281 } // namespace tensorflow
282