1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24
operator <(const DeviceType & other) const25 bool DeviceType::operator<(const DeviceType& other) const {
26 return type_ < other.type_;
27 }
28
operator ==(const DeviceType & other) const29 bool DeviceType::operator==(const DeviceType& other) const {
30 return type_ == other.type_;
31 }
32
operator <<(std::ostream & os,const DeviceType & d)33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
34 os << d.type();
35 return os;
36 }
37
// String identifiers for the device kinds named in this file.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_TPU_SYSTEM{"TPU_SYSTEM"};
42
43 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
44 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
45 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
46 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
47 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
48
49 namespace {
DataTypeStringInternal(DataType dtype)50 string DataTypeStringInternal(DataType dtype) {
51 switch (dtype) {
52 case DT_INVALID:
53 return "INVALID";
54 case DT_FLOAT:
55 return "float";
56 case DT_DOUBLE:
57 return "double";
58 case DT_INT32:
59 return "int32";
60 case DT_UINT32:
61 return "uint32";
62 case DT_UINT8:
63 return "uint8";
64 case DT_UINT16:
65 return "uint16";
66 case DT_INT16:
67 return "int16";
68 case DT_INT8:
69 return "int8";
70 case DT_STRING:
71 return "string";
72 case DT_COMPLEX64:
73 return "complex64";
74 case DT_COMPLEX128:
75 return "complex128";
76 case DT_INT64:
77 return "int64";
78 case DT_UINT64:
79 return "uint64";
80 case DT_BOOL:
81 return "bool";
82 case DT_QINT8:
83 return "qint8";
84 case DT_QUINT8:
85 return "quint8";
86 case DT_QUINT16:
87 return "quint16";
88 case DT_QINT16:
89 return "qint16";
90 case DT_QINT32:
91 return "qint32";
92 case DT_BFLOAT16:
93 return "bfloat16";
94 case DT_HALF:
95 return "half";
96 case DT_RESOURCE:
97 return "resource";
98 case DT_VARIANT:
99 return "variant";
100 default:
101 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
102 return strings::StrCat("unknown dtype enum (", dtype, ")");
103 }
104 }
105 } // end namespace
106
DataTypeString(DataType dtype)107 string DataTypeString(DataType dtype) {
108 if (IsRefType(dtype)) {
109 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
110 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
111 }
112 return DataTypeStringInternal(dtype);
113 }
114
DataTypeFromString(StringPiece sp,DataType * dt)115 bool DataTypeFromString(StringPiece sp, DataType* dt) {
116 if (str_util::EndsWith(sp, "_ref")) {
117 sp.remove_suffix(4);
118 DataType non_ref;
119 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
120 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
121 return true;
122 } else {
123 return false;
124 }
125 }
126
127 if (sp == "float" || sp == "float32") {
128 *dt = DT_FLOAT;
129 return true;
130 } else if (sp == "double" || sp == "float64") {
131 *dt = DT_DOUBLE;
132 return true;
133 } else if (sp == "int32") {
134 *dt = DT_INT32;
135 return true;
136 } else if (sp == "uint32") {
137 *dt = DT_UINT32;
138 return true;
139 } else if (sp == "uint8") {
140 *dt = DT_UINT8;
141 return true;
142 } else if (sp == "uint16") {
143 *dt = DT_UINT16;
144 return true;
145 } else if (sp == "int16") {
146 *dt = DT_INT16;
147 return true;
148 } else if (sp == "int8") {
149 *dt = DT_INT8;
150 return true;
151 } else if (sp == "string") {
152 *dt = DT_STRING;
153 return true;
154 } else if (sp == "complex64") {
155 *dt = DT_COMPLEX64;
156 return true;
157 } else if (sp == "complex128") {
158 *dt = DT_COMPLEX128;
159 return true;
160 } else if (sp == "int64") {
161 *dt = DT_INT64;
162 return true;
163 } else if (sp == "uint64") {
164 *dt = DT_UINT64;
165 return true;
166 } else if (sp == "bool") {
167 *dt = DT_BOOL;
168 return true;
169 } else if (sp == "qint8") {
170 *dt = DT_QINT8;
171 return true;
172 } else if (sp == "quint8") {
173 *dt = DT_QUINT8;
174 return true;
175 } else if (sp == "qint16") {
176 *dt = DT_QINT16;
177 return true;
178 } else if (sp == "quint16") {
179 *dt = DT_QUINT16;
180 return true;
181 } else if (sp == "qint32") {
182 *dt = DT_QINT32;
183 return true;
184 } else if (sp == "bfloat16") {
185 *dt = DT_BFLOAT16;
186 return true;
187 } else if (sp == "half" || sp == "float16") {
188 *dt = DT_HALF;
189 return true;
190 } else if (sp == "resource") {
191 *dt = DT_RESOURCE;
192 return true;
193 } else if (sp == "variant") {
194 *dt = DT_VARIANT;
195 return true;
196 }
197 return false;
198 }
199
DeviceTypeString(const DeviceType & device_type)200 string DeviceTypeString(const DeviceType& device_type) {
201 return device_type.type();
202 }
203
DataTypeSliceString(const DataTypeSlice types)204 string DataTypeSliceString(const DataTypeSlice types) {
205 string out;
206 for (auto it = types.begin(); it != types.end(); ++it) {
207 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
208 DataTypeString(*it));
209 }
210 return out;
211 }
212
DataTypeAlwaysOnHost(DataType dt)213 bool DataTypeAlwaysOnHost(DataType dt) {
214 // Includes DT_STRING and DT_RESOURCE.
215 switch (dt) {
216 case DT_STRING:
217 case DT_STRING_REF:
218 case DT_RESOURCE:
219 return true;
220 default:
221 return false;
222 }
223 }
224
DataTypeSize(DataType dt)225 int DataTypeSize(DataType dt) {
226 #define CASE(T) \
227 case DataTypeToEnum<T>::value: \
228 return sizeof(T);
229 switch (dt) {
230 TF_CALL_POD_TYPES(CASE);
231 TF_CALL_QUANTIZED_TYPES(CASE);
232 // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
233 // they are not supported widely, but are explicitly listed here for
234 // bitcast.
235 TF_CALL_qint16(CASE);
236 TF_CALL_quint16(CASE);
237
238 default:
239 return 0;
240 }
241 #undef CASE
242 }
243
244 // Define DataTypeToEnum<T>::value.
245 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
246 constexpr DataType DataTypeToEnum<TYPE>::value;
247
248 DEFINE_DATATYPETOENUM_VALUE(float);
249 DEFINE_DATATYPETOENUM_VALUE(double);
250 DEFINE_DATATYPETOENUM_VALUE(int32);
251 DEFINE_DATATYPETOENUM_VALUE(uint32);
252 DEFINE_DATATYPETOENUM_VALUE(uint16);
253 DEFINE_DATATYPETOENUM_VALUE(uint8);
254 DEFINE_DATATYPETOENUM_VALUE(int16);
255 DEFINE_DATATYPETOENUM_VALUE(int8);
256 DEFINE_DATATYPETOENUM_VALUE(tstring);
257 DEFINE_DATATYPETOENUM_VALUE(complex64);
258 DEFINE_DATATYPETOENUM_VALUE(complex128);
259 DEFINE_DATATYPETOENUM_VALUE(int64);
260 DEFINE_DATATYPETOENUM_VALUE(uint64);
261 DEFINE_DATATYPETOENUM_VALUE(bool);
262 DEFINE_DATATYPETOENUM_VALUE(qint8);
263 DEFINE_DATATYPETOENUM_VALUE(quint8);
264 DEFINE_DATATYPETOENUM_VALUE(qint16);
265 DEFINE_DATATYPETOENUM_VALUE(quint16);
266 DEFINE_DATATYPETOENUM_VALUE(qint32);
267 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
268 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
269 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
270 DEFINE_DATATYPETOENUM_VALUE(Variant);
271 #undef DEFINE_DATATYPETOENUM_VALUE
272
273 } // namespace tensorflow
274