1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24
operator <(const DeviceType & other) const25 bool DeviceType::operator<(const DeviceType& other) const {
26 return type_ < other.type_;
27 }
28
operator ==(const DeviceType & other) const29 bool DeviceType::operator==(const DeviceType& other) const {
30 return type_ == other.type_;
31 }
32
operator <<(std::ostream & os,const DeviceType & d)33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
34 os << d.type();
35 return os;
36 }
37
// Canonical device-type name strings. NOTE(review): these namespace-scope
// definitions presumably correspond to `extern` declarations in types.h —
// confirm before changing their form or linkage.
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";

// Definitions of the static `value` member of DeviceName<...> for each Eigen
// device type, tying it to the matching name above. The GPU entry is only
// compiled when GOOGLE_CUDA is truthy; the SYCL entry only when
// TENSORFLOW_USE_SYCL is defined.
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
#if GOOGLE_CUDA
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
#endif  // TENSORFLOW_USE_SYCL
49
50 namespace {
DataTypeStringInternal(DataType dtype)51 string DataTypeStringInternal(DataType dtype) {
52 switch (dtype) {
53 case DT_INVALID:
54 return "INVALID";
55 case DT_FLOAT:
56 return "float";
57 case DT_DOUBLE:
58 return "double";
59 case DT_INT32:
60 return "int32";
61 case DT_UINT32:
62 return "uint32";
63 case DT_UINT8:
64 return "uint8";
65 case DT_UINT16:
66 return "uint16";
67 case DT_INT16:
68 return "int16";
69 case DT_INT8:
70 return "int8";
71 case DT_STRING:
72 return "string";
73 case DT_COMPLEX64:
74 return "complex64";
75 case DT_COMPLEX128:
76 return "complex128";
77 case DT_INT64:
78 return "int64";
79 case DT_UINT64:
80 return "uint64";
81 case DT_BOOL:
82 return "bool";
83 case DT_QINT8:
84 return "qint8";
85 case DT_QUINT8:
86 return "quint8";
87 case DT_QUINT16:
88 return "quint16";
89 case DT_QINT16:
90 return "qint16";
91 case DT_QINT32:
92 return "qint32";
93 case DT_BFLOAT16:
94 return "bfloat16";
95 case DT_HALF:
96 return "half";
97 case DT_RESOURCE:
98 return "resource";
99 case DT_VARIANT:
100 return "variant";
101 default:
102 LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
103 return strings::StrCat("unknown dtype enum (", dtype, ")");
104 }
105 }
106 } // end namespace
107
DataTypeString(DataType dtype)108 string DataTypeString(DataType dtype) {
109 if (IsRefType(dtype)) {
110 DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
111 return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
112 }
113 return DataTypeStringInternal(dtype);
114 }
115
DataTypeFromString(StringPiece sp,DataType * dt)116 bool DataTypeFromString(StringPiece sp, DataType* dt) {
117 if (sp.ends_with("_ref")) {
118 sp.remove_suffix(4);
119 DataType non_ref;
120 if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
121 *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
122 return true;
123 } else {
124 return false;
125 }
126 }
127
128 if (sp == "float" || sp == "float32") {
129 *dt = DT_FLOAT;
130 return true;
131 } else if (sp == "double" || sp == "float64") {
132 *dt = DT_DOUBLE;
133 return true;
134 } else if (sp == "int32") {
135 *dt = DT_INT32;
136 return true;
137 } else if (sp == "uint32") {
138 *dt = DT_UINT32;
139 return true;
140 } else if (sp == "uint8") {
141 *dt = DT_UINT8;
142 return true;
143 } else if (sp == "uint16") {
144 *dt = DT_UINT16;
145 return true;
146 } else if (sp == "int16") {
147 *dt = DT_INT16;
148 return true;
149 } else if (sp == "int8") {
150 *dt = DT_INT8;
151 return true;
152 } else if (sp == "string") {
153 *dt = DT_STRING;
154 return true;
155 } else if (sp == "complex64") {
156 *dt = DT_COMPLEX64;
157 return true;
158 } else if (sp == "complex128") {
159 *dt = DT_COMPLEX128;
160 return true;
161 } else if (sp == "int64") {
162 *dt = DT_INT64;
163 return true;
164 } else if (sp == "uint64") {
165 *dt = DT_UINT64;
166 return true;
167 } else if (sp == "bool") {
168 *dt = DT_BOOL;
169 return true;
170 } else if (sp == "qint8") {
171 *dt = DT_QINT8;
172 return true;
173 } else if (sp == "quint8") {
174 *dt = DT_QUINT8;
175 return true;
176 } else if (sp == "qint16") {
177 *dt = DT_QINT16;
178 return true;
179 } else if (sp == "quint16") {
180 *dt = DT_QUINT16;
181 return true;
182 } else if (sp == "qint32") {
183 *dt = DT_QINT32;
184 return true;
185 } else if (sp == "bfloat16") {
186 *dt = DT_BFLOAT16;
187 return true;
188 } else if (sp == "half" || sp == "float16") {
189 *dt = DT_HALF;
190 return true;
191 } else if (sp == "resource") {
192 *dt = DT_RESOURCE;
193 return true;
194 } else if (sp == "variant") {
195 *dt = DT_VARIANT;
196 return true;
197 }
198 return false;
199 }
200
DeviceTypeString(const DeviceType & device_type)201 string DeviceTypeString(const DeviceType& device_type) {
202 return device_type.type();
203 }
204
DataTypeSliceString(const DataTypeSlice types)205 string DataTypeSliceString(const DataTypeSlice types) {
206 string out;
207 for (auto it = types.begin(); it != types.end(); ++it) {
208 strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
209 DataTypeString(*it));
210 }
211 return out;
212 }
213
DataTypeAlwaysOnHost(DataType dt)214 bool DataTypeAlwaysOnHost(DataType dt) {
215 // Includes DT_STRING and DT_RESOURCE.
216 switch (dt) {
217 case DT_STRING:
218 case DT_STRING_REF:
219 case DT_RESOURCE:
220 return true;
221 default:
222 return false;
223 }
224 }
225
// Returns sizeof(T) for the C++ type T that DataTypeToEnum<T> maps to `dt`,
// for every type expanded by the TF_CALL_* macros below. Any other DataType
// returns 0. NOTE(review): that presumably includes reference types and
// non-POD types such as string/resource/variant — confirm against the macro
// definitions in register_types.h.
int DataTypeSize(DataType dt) {
// Emits one `case` label per C++ type T, returning that type's byte size.
#define CASE(T) \
  case DataTypeToEnum<T>::value: \
    return sizeof(T);
  switch (dt) {
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_QUANTIZED_TYPES(CASE);
    // The TF_CALL_QUANTIZED_TYPES() macro does not cover quint16 and qint16,
    // since they are not widely supported, but they are explicitly listed here
    // for bitcast.
    TF_CALL_qint16(CASE);
    TF_CALL_quint16(CASE);

    // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
    // don't want to define kernels for them at this stage to avoid binary
    // bloat.
    TF_CALL_uint32(CASE);
    TF_CALL_uint64(CASE);
    default:
      return 0;
  }
#undef CASE
}
249
250 } // namespace tensorflow
251