• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/logging.h"
22 
23 namespace tensorflow {
24 
operator <(const DeviceType & other) const25 bool DeviceType::operator<(const DeviceType& other) const {
26   return type_ < other.type_;
27 }
28 
operator ==(const DeviceType & other) const29 bool DeviceType::operator==(const DeviceType& other) const {
30   return type_ == other.type_;
31 }
32 
operator <<(std::ostream & os,const DeviceType & d)33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
34   os << d.type();
35   return os;
36 }
37 
// Canonical device-type name strings. Both the pointer and the pointee are
// const, and each is initialized from a string literal.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_SYCL{"SYCL"};
42 
43 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
44 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
45     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
46 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
47 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
48 #ifdef TENSORFLOW_USE_SYCL
49 const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
50 #endif  // TENSORFLOW_USE_SYCL
51 
52 namespace {
DataTypeStringInternal(DataType dtype)53 string DataTypeStringInternal(DataType dtype) {
54   switch (dtype) {
55     case DT_INVALID:
56       return "INVALID";
57     case DT_FLOAT:
58       return "float";
59     case DT_DOUBLE:
60       return "double";
61     case DT_INT32:
62       return "int32";
63     case DT_UINT32:
64       return "uint32";
65     case DT_UINT8:
66       return "uint8";
67     case DT_UINT16:
68       return "uint16";
69     case DT_INT16:
70       return "int16";
71     case DT_INT8:
72       return "int8";
73     case DT_STRING:
74       return "string";
75     case DT_COMPLEX64:
76       return "complex64";
77     case DT_COMPLEX128:
78       return "complex128";
79     case DT_INT64:
80       return "int64";
81     case DT_UINT64:
82       return "uint64";
83     case DT_BOOL:
84       return "bool";
85     case DT_QINT8:
86       return "qint8";
87     case DT_QUINT8:
88       return "quint8";
89     case DT_QUINT16:
90       return "quint16";
91     case DT_QINT16:
92       return "qint16";
93     case DT_QINT32:
94       return "qint32";
95     case DT_BFLOAT16:
96       return "bfloat16";
97     case DT_HALF:
98       return "half";
99     case DT_RESOURCE:
100       return "resource";
101     case DT_VARIANT:
102       return "variant";
103     default:
104       LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
105       return strings::StrCat("unknown dtype enum (", dtype, ")");
106   }
107 }
108 }  // end namespace
109 
DataTypeString(DataType dtype)110 string DataTypeString(DataType dtype) {
111   if (IsRefType(dtype)) {
112     DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
113     return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
114   }
115   return DataTypeStringInternal(dtype);
116 }
117 
DataTypeFromString(StringPiece sp,DataType * dt)118 bool DataTypeFromString(StringPiece sp, DataType* dt) {
119   if (str_util::EndsWith(sp, "_ref")) {
120     sp.remove_suffix(4);
121     DataType non_ref;
122     if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
123       *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
124       return true;
125     } else {
126       return false;
127     }
128   }
129 
130   if (sp == "float" || sp == "float32") {
131     *dt = DT_FLOAT;
132     return true;
133   } else if (sp == "double" || sp == "float64") {
134     *dt = DT_DOUBLE;
135     return true;
136   } else if (sp == "int32") {
137     *dt = DT_INT32;
138     return true;
139   } else if (sp == "uint32") {
140     *dt = DT_UINT32;
141     return true;
142   } else if (sp == "uint8") {
143     *dt = DT_UINT8;
144     return true;
145   } else if (sp == "uint16") {
146     *dt = DT_UINT16;
147     return true;
148   } else if (sp == "int16") {
149     *dt = DT_INT16;
150     return true;
151   } else if (sp == "int8") {
152     *dt = DT_INT8;
153     return true;
154   } else if (sp == "string") {
155     *dt = DT_STRING;
156     return true;
157   } else if (sp == "complex64") {
158     *dt = DT_COMPLEX64;
159     return true;
160   } else if (sp == "complex128") {
161     *dt = DT_COMPLEX128;
162     return true;
163   } else if (sp == "int64") {
164     *dt = DT_INT64;
165     return true;
166   } else if (sp == "uint64") {
167     *dt = DT_UINT64;
168     return true;
169   } else if (sp == "bool") {
170     *dt = DT_BOOL;
171     return true;
172   } else if (sp == "qint8") {
173     *dt = DT_QINT8;
174     return true;
175   } else if (sp == "quint8") {
176     *dt = DT_QUINT8;
177     return true;
178   } else if (sp == "qint16") {
179     *dt = DT_QINT16;
180     return true;
181   } else if (sp == "quint16") {
182     *dt = DT_QUINT16;
183     return true;
184   } else if (sp == "qint32") {
185     *dt = DT_QINT32;
186     return true;
187   } else if (sp == "bfloat16") {
188     *dt = DT_BFLOAT16;
189     return true;
190   } else if (sp == "half" || sp == "float16") {
191     *dt = DT_HALF;
192     return true;
193   } else if (sp == "resource") {
194     *dt = DT_RESOURCE;
195     return true;
196   } else if (sp == "variant") {
197     *dt = DT_VARIANT;
198     return true;
199   }
200   return false;
201 }
202 
DeviceTypeString(const DeviceType & device_type)203 string DeviceTypeString(const DeviceType& device_type) {
204   return device_type.type();
205 }
206 
DataTypeSliceString(const DataTypeSlice types)207 string DataTypeSliceString(const DataTypeSlice types) {
208   string out;
209   for (auto it = types.begin(); it != types.end(); ++it) {
210     strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
211                        DataTypeString(*it));
212   }
213   return out;
214 }
215 
DataTypeAlwaysOnHost(DataType dt)216 bool DataTypeAlwaysOnHost(DataType dt) {
217   // Includes DT_STRING and DT_RESOURCE.
218   switch (dt) {
219     case DT_STRING:
220     case DT_STRING_REF:
221     case DT_RESOURCE:
222       return true;
223     default:
224       return false;
225   }
226 }
227 
DataTypeSize(DataType dt)228 int DataTypeSize(DataType dt) {
229 #define CASE(T)                  \
230   case DataTypeToEnum<T>::value: \
231     return sizeof(T);
232   switch (dt) {
233     TF_CALL_POD_TYPES(CASE);
234     TF_CALL_QUANTIZED_TYPES(CASE);
235     // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
236     // they are not supported widely, but are explicitly listed here for
237     // bitcast.
238     TF_CALL_qint16(CASE);
239     TF_CALL_quint16(CASE);
240 
241     // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
242     // don't want to define kernels for them at this stage to avoid binary
243     // bloat.
244     TF_CALL_uint32(CASE);
245     TF_CALL_uint64(CASE);
246     default:
247       return 0;
248   }
249 #undef CASE
250 }
251 
252 // Define DataTypeToEnum<T>::value.
253 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
254   constexpr DataType DataTypeToEnum<TYPE>::value;
255 
256 DEFINE_DATATYPETOENUM_VALUE(float);
257 DEFINE_DATATYPETOENUM_VALUE(double);
258 DEFINE_DATATYPETOENUM_VALUE(int32);
259 DEFINE_DATATYPETOENUM_VALUE(uint32);
260 DEFINE_DATATYPETOENUM_VALUE(uint16);
261 DEFINE_DATATYPETOENUM_VALUE(uint8);
262 DEFINE_DATATYPETOENUM_VALUE(int16);
263 DEFINE_DATATYPETOENUM_VALUE(int8);
264 DEFINE_DATATYPETOENUM_VALUE(tstring);
265 DEFINE_DATATYPETOENUM_VALUE(complex64);
266 DEFINE_DATATYPETOENUM_VALUE(complex128);
267 DEFINE_DATATYPETOENUM_VALUE(int64);
268 DEFINE_DATATYPETOENUM_VALUE(uint64);
269 DEFINE_DATATYPETOENUM_VALUE(bool);
270 DEFINE_DATATYPETOENUM_VALUE(qint8);
271 DEFINE_DATATYPETOENUM_VALUE(quint8);
272 DEFINE_DATATYPETOENUM_VALUE(qint16);
273 DEFINE_DATATYPETOENUM_VALUE(quint16);
274 DEFINE_DATATYPETOENUM_VALUE(qint32);
275 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
276 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
277 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
278 DEFINE_DATATYPETOENUM_VALUE(Variant);
279 #undef DEFINE_DATATYPETOENUM_VALUE
280 
281 }  // namespace tensorflow
282