• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 #include "tensorflow/core/lib/strings/str_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace tensorflow {
23 
24 struct DataTypeHasher {
operator ()tensorflow::DataTypeHasher25   std::size_t operator()(const DataType& k) const {
26     return std::hash<int>()(static_cast<int>(k));
27   }
28 };
29 
30 // Mapping from some of the DType fields, for backward compatibility. All other
31 // dtypes are mapped to TFT_ANY, but can be added here if a counterpart is
32 // defined.
33 auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({
34     {DT_FLOAT, TFT_FLOAT},
35     {DT_DOUBLE, TFT_DOUBLE},
36     {DT_INT32, TFT_INT32},
37     {DT_UINT8, TFT_UINT8},
38     {DT_INT16, TFT_INT16},
39     {DT_INT8, TFT_INT8},
40     {DT_STRING, TFT_STRING},
41     {DT_COMPLEX64, TFT_COMPLEX64},
42     {DT_INT64, TFT_INT64},
43     {DT_BOOL, TFT_BOOL},
44     {DT_UINT16, TFT_UINT16},
45     {DT_COMPLEX128, TFT_COMPLEX128},
46     {DT_HALF, TFT_HALF},
47     {DT_UINT32, TFT_UINT32},
48     {DT_UINT64, TFT_UINT64},
49 });
50 
map_dtype_to_tensor(const DataType & dtype,FullTypeDef * t)51 void map_dtype_to_tensor(const DataType& dtype, FullTypeDef* t) {
52   t->set_type_id(TFT_TENSOR);
53   // If the dtype is not mapped, assume it's not supported and use
54   // TFT_ANY, for compatibility. See DT_TO_FT for more details.
55   const auto& mapped = DT_TO_FT->find(dtype);
56   auto* arg = t->add_args();
57   if (mapped != DT_TO_FT->end()) {
58     arg->set_type_id(mapped->second);
59   } else {
60     arg->set_type_id(TFT_ANY);
61   }
62 }
63 
operator <(const DeviceType & other) const64 bool DeviceType::operator<(const DeviceType& other) const {
65   return type_ < other.type_;
66 }
67 
operator ==(const DeviceType & other) const68 bool DeviceType::operator==(const DeviceType& other) const {
69   return type_ == other.type_;
70 }
71 
operator <<(std::ostream & os,const DeviceType & d)72 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
73   os << d.type();
74   return os;
75 }
76 
// Canonical device-type name strings.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_TPU_SYSTEM{"TPU_SYSTEM"};
81 
82 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
83 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
84     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
85 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
86 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
87 
88 namespace {
DataTypeStringInternal(DataType dtype)89 string DataTypeStringInternal(DataType dtype) {
90   switch (dtype) {
91     case DT_INVALID:
92       return "INVALID";
93     case DT_FLOAT:
94       return "float";
95     case DT_DOUBLE:
96       return "double";
97     case DT_INT32:
98       return "int32";
99     case DT_UINT32:
100       return "uint32";
101     case DT_UINT8:
102       return "uint8";
103     case DT_UINT16:
104       return "uint16";
105     case DT_INT16:
106       return "int16";
107     case DT_INT8:
108       return "int8";
109     case DT_STRING:
110       return "string";
111     case DT_COMPLEX64:
112       return "complex64";
113     case DT_COMPLEX128:
114       return "complex128";
115     case DT_INT64:
116       return "int64";
117     case DT_UINT64:
118       return "uint64";
119     case DT_BOOL:
120       return "bool";
121     case DT_QINT8:
122       return "qint8";
123     case DT_QUINT8:
124       return "quint8";
125     case DT_QUINT16:
126       return "quint16";
127     case DT_QINT16:
128       return "qint16";
129     case DT_QINT32:
130       return "qint32";
131     case DT_BFLOAT16:
132       return "bfloat16";
133     case DT_HALF:
134       return "half";
135     case DT_RESOURCE:
136       return "resource";
137     case DT_VARIANT:
138       return "variant";
139     default:
140       LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
141       return strings::StrCat("unknown dtype enum (", dtype, ")");
142   }
143 }
144 }  // end namespace
145 
DataTypeString(DataType dtype)146 string DataTypeString(DataType dtype) {
147   if (IsRefType(dtype)) {
148     DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
149     return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
150   }
151   return DataTypeStringInternal(dtype);
152 }
153 
DataTypeFromString(StringPiece sp,DataType * dt)154 bool DataTypeFromString(StringPiece sp, DataType* dt) {
155   if (str_util::EndsWith(sp, "_ref")) {
156     sp.remove_suffix(4);
157     DataType non_ref;
158     if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
159       *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
160       return true;
161     } else {
162       return false;
163     }
164   }
165 
166   if (sp == "float" || sp == "float32") {
167     *dt = DT_FLOAT;
168     return true;
169   } else if (sp == "double" || sp == "float64") {
170     *dt = DT_DOUBLE;
171     return true;
172   } else if (sp == "int32") {
173     *dt = DT_INT32;
174     return true;
175   } else if (sp == "uint32") {
176     *dt = DT_UINT32;
177     return true;
178   } else if (sp == "uint8") {
179     *dt = DT_UINT8;
180     return true;
181   } else if (sp == "uint16") {
182     *dt = DT_UINT16;
183     return true;
184   } else if (sp == "int16") {
185     *dt = DT_INT16;
186     return true;
187   } else if (sp == "int8") {
188     *dt = DT_INT8;
189     return true;
190   } else if (sp == "string") {
191     *dt = DT_STRING;
192     return true;
193   } else if (sp == "complex64") {
194     *dt = DT_COMPLEX64;
195     return true;
196   } else if (sp == "complex128") {
197     *dt = DT_COMPLEX128;
198     return true;
199   } else if (sp == "int64") {
200     *dt = DT_INT64;
201     return true;
202   } else if (sp == "uint64") {
203     *dt = DT_UINT64;
204     return true;
205   } else if (sp == "bool") {
206     *dt = DT_BOOL;
207     return true;
208   } else if (sp == "qint8") {
209     *dt = DT_QINT8;
210     return true;
211   } else if (sp == "quint8") {
212     *dt = DT_QUINT8;
213     return true;
214   } else if (sp == "qint16") {
215     *dt = DT_QINT16;
216     return true;
217   } else if (sp == "quint16") {
218     *dt = DT_QUINT16;
219     return true;
220   } else if (sp == "qint32") {
221     *dt = DT_QINT32;
222     return true;
223   } else if (sp == "bfloat16") {
224     *dt = DT_BFLOAT16;
225     return true;
226   } else if (sp == "half" || sp == "float16") {
227     *dt = DT_HALF;
228     return true;
229   } else if (sp == "resource") {
230     *dt = DT_RESOURCE;
231     return true;
232   } else if (sp == "variant") {
233     *dt = DT_VARIANT;
234     return true;
235   }
236   return false;
237 }
238 
DeviceTypeString(const DeviceType & device_type)239 string DeviceTypeString(const DeviceType& device_type) {
240   return device_type.type();
241 }
242 
DataTypeSliceString(const DataTypeSlice types)243 string DataTypeSliceString(const DataTypeSlice types) {
244   string out;
245   for (auto it = types.begin(); it != types.end(); ++it) {
246     strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
247                        DataTypeString(*it));
248   }
249   return out;
250 }
251 
DataTypeAlwaysOnHost(DataType dt)252 bool DataTypeAlwaysOnHost(DataType dt) {
253   // Includes DT_STRING and DT_RESOURCE.
254   switch (dt) {
255     case DT_STRING:
256     case DT_STRING_REF:
257     case DT_RESOURCE:
258       return true;
259     default:
260       return false;
261   }
262 }
263 
DataTypeSize(DataType dt)264 int DataTypeSize(DataType dt) {
265 #define CASE(T)                  \
266   case DataTypeToEnum<T>::value: \
267     return sizeof(T);
268   switch (dt) {
269     TF_CALL_POD_TYPES(CASE);
270     TF_CALL_QUANTIZED_TYPES(CASE);
271     // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
272     // they are not supported widely, but are explicitly listed here for
273     // bitcast.
274     TF_CALL_qint16(CASE);
275     TF_CALL_quint16(CASE);
276 
277     default:
278       return 0;
279   }
280 #undef CASE
281 }
282 
283 // Define DataTypeToEnum<T>::value.
284 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
285   constexpr DataType DataTypeToEnum<TYPE>::value;
286 
287 DEFINE_DATATYPETOENUM_VALUE(float);
288 DEFINE_DATATYPETOENUM_VALUE(double);
289 DEFINE_DATATYPETOENUM_VALUE(int32);
290 DEFINE_DATATYPETOENUM_VALUE(uint32);
291 DEFINE_DATATYPETOENUM_VALUE(uint16);
292 DEFINE_DATATYPETOENUM_VALUE(uint8);
293 DEFINE_DATATYPETOENUM_VALUE(int16);
294 DEFINE_DATATYPETOENUM_VALUE(int8);
295 DEFINE_DATATYPETOENUM_VALUE(tstring);
296 DEFINE_DATATYPETOENUM_VALUE(complex64);
297 DEFINE_DATATYPETOENUM_VALUE(complex128);
298 DEFINE_DATATYPETOENUM_VALUE(int64);
299 DEFINE_DATATYPETOENUM_VALUE(uint64);
300 DEFINE_DATATYPETOENUM_VALUE(bool);
301 DEFINE_DATATYPETOENUM_VALUE(qint8);
302 DEFINE_DATATYPETOENUM_VALUE(quint8);
303 DEFINE_DATATYPETOENUM_VALUE(qint16);
304 DEFINE_DATATYPETOENUM_VALUE(quint16);
305 DEFINE_DATATYPETOENUM_VALUE(qint32);
306 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
307 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
308 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
309 DEFINE_DATATYPETOENUM_VALUE(Variant);
310 #undef DEFINE_DATATYPETOENUM_VALUE
311 
312 }  // namespace tensorflow
313