• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 #include "tensorflow/core/lib/strings/str_util.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace tensorflow {
23 
24 struct DataTypeHasher {
operator ()tensorflow::DataTypeHasher25   std::size_t operator()(const DataType& k) const {
26     return std::hash<int>()(static_cast<int>(k));
27   }
28 };
29 
30 // Mapping from some of the DType fields, for backward compatibility. All other
31 // dtypes are mapped to TFT_ANY, but can be added here if a counterpart is
32 // defined.
33 auto* DT_TO_FT = new std::unordered_map<DataType, FullTypeId, DataTypeHasher>({
34     {DT_FLOAT, TFT_FLOAT},
35     {DT_DOUBLE, TFT_DOUBLE},
36     {DT_INT32, TFT_INT32},
37     {DT_UINT8, TFT_UINT8},
38     {DT_INT16, TFT_INT16},
39     {DT_INT8, TFT_INT8},
40     {DT_STRING, TFT_STRING},
41     {DT_COMPLEX64, TFT_COMPLEX64},
42     {DT_INT64, TFT_INT64},
43     {DT_BOOL, TFT_BOOL},
44     {DT_UINT16, TFT_UINT16},
45     {DT_COMPLEX128, TFT_COMPLEX128},
46     {DT_HALF, TFT_HALF},
47     {DT_UINT32, TFT_UINT32},
48     {DT_UINT64, TFT_UINT64},
49     {DT_VARIANT, TFT_LEGACY_VARIANT},
50 });
51 
map_dtype_to_tensor(const DataType & dtype,FullTypeDef & t)52 void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t) {
53   t.Clear();
54 
55   const auto& mapped = DT_TO_FT->find(dtype);
56   // Only map known types, everything else remains unset. This is so that we
57   // only set the most specific type when it is fully known. For example, if the
58   // dtype is DT_VARIANT, then we don't know much and opt to assume that
59   // the type is unset, rather than TFT_ANY.
60   if (mapped != DT_TO_FT->end()) {
61     t.set_type_id(mapped->second);
62   }
63 }
64 
operator <(const DeviceType & other) const65 bool DeviceType::operator<(const DeviceType& other) const {
66   return type_ < other.type_;
67 }
68 
operator ==(const DeviceType & other) const69 bool DeviceType::operator==(const DeviceType& other) const {
70   return type_ == other.type_;
71 }
72 
operator <<(std::ostream & os,const DeviceType & d)73 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
74   os << d.type();
75   return os;
76 }
77 
// Canonical names of the built-in device types.
const char* const DEVICE_DEFAULT{"DEFAULT"};
const char* const DEVICE_CPU{"CPU"};
const char* const DEVICE_GPU{"GPU"};
const char* const DEVICE_TPU{"TPU"};
const char* const DEVICE_TPU_SYSTEM{"TPU_SYSTEM"};
83 
84 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
85 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
86     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
87 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
88 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
89 
90 namespace {
DataTypeStringInternal(DataType dtype)91 string DataTypeStringInternal(DataType dtype) {
92   switch (dtype) {
93     case DT_INVALID:
94       return "INVALID";
95     case DT_FLOAT:
96       return "float";
97     case DT_DOUBLE:
98       return "double";
99     case DT_INT32:
100       return "int32";
101     case DT_UINT32:
102       return "uint32";
103     case DT_UINT8:
104       return "uint8";
105     case DT_UINT16:
106       return "uint16";
107     case DT_INT16:
108       return "int16";
109     case DT_INT8:
110       return "int8";
111     case DT_STRING:
112       return "string";
113     case DT_COMPLEX64:
114       return "complex64";
115     case DT_COMPLEX128:
116       return "complex128";
117     case DT_INT64:
118       return "int64";
119     case DT_UINT64:
120       return "uint64";
121     case DT_BOOL:
122       return "bool";
123     case DT_QINT8:
124       return "qint8";
125     case DT_QUINT8:
126       return "quint8";
127     case DT_QUINT16:
128       return "quint16";
129     case DT_QINT16:
130       return "qint16";
131     case DT_QINT32:
132       return "qint32";
133     case DT_BFLOAT16:
134       return "bfloat16";
135     case DT_HALF:
136       return "half";
137     case DT_RESOURCE:
138       return "resource";
139     case DT_VARIANT:
140       return "variant";
141     default:
142       LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
143       return strings::StrCat("unknown dtype enum (", dtype, ")");
144   }
145 }
146 }  // end namespace
147 
DataTypeString(DataType dtype)148 string DataTypeString(DataType dtype) {
149   if (IsRefType(dtype)) {
150     DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
151     return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
152   }
153   return DataTypeStringInternal(dtype);
154 }
155 
DataTypeFromString(StringPiece sp,DataType * dt)156 bool DataTypeFromString(StringPiece sp, DataType* dt) {
157   if (str_util::EndsWith(sp, "_ref")) {
158     sp.remove_suffix(4);
159     DataType non_ref;
160     if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
161       *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
162       return true;
163     } else {
164       return false;
165     }
166   }
167 
168   if (sp == "float" || sp == "float32") {
169     *dt = DT_FLOAT;
170     return true;
171   } else if (sp == "double" || sp == "float64") {
172     *dt = DT_DOUBLE;
173     return true;
174   } else if (sp == "int32") {
175     *dt = DT_INT32;
176     return true;
177   } else if (sp == "uint32") {
178     *dt = DT_UINT32;
179     return true;
180   } else if (sp == "uint8") {
181     *dt = DT_UINT8;
182     return true;
183   } else if (sp == "uint16") {
184     *dt = DT_UINT16;
185     return true;
186   } else if (sp == "int16") {
187     *dt = DT_INT16;
188     return true;
189   } else if (sp == "int8") {
190     *dt = DT_INT8;
191     return true;
192   } else if (sp == "string") {
193     *dt = DT_STRING;
194     return true;
195   } else if (sp == "complex64") {
196     *dt = DT_COMPLEX64;
197     return true;
198   } else if (sp == "complex128") {
199     *dt = DT_COMPLEX128;
200     return true;
201   } else if (sp == "int64") {
202     *dt = DT_INT64;
203     return true;
204   } else if (sp == "uint64") {
205     *dt = DT_UINT64;
206     return true;
207   } else if (sp == "bool") {
208     *dt = DT_BOOL;
209     return true;
210   } else if (sp == "qint8") {
211     *dt = DT_QINT8;
212     return true;
213   } else if (sp == "quint8") {
214     *dt = DT_QUINT8;
215     return true;
216   } else if (sp == "qint16") {
217     *dt = DT_QINT16;
218     return true;
219   } else if (sp == "quint16") {
220     *dt = DT_QUINT16;
221     return true;
222   } else if (sp == "qint32") {
223     *dt = DT_QINT32;
224     return true;
225   } else if (sp == "bfloat16") {
226     *dt = DT_BFLOAT16;
227     return true;
228   } else if (sp == "half" || sp == "float16") {
229     *dt = DT_HALF;
230     return true;
231   } else if (sp == "resource") {
232     *dt = DT_RESOURCE;
233     return true;
234   } else if (sp == "variant") {
235     *dt = DT_VARIANT;
236     return true;
237   }
238   return false;
239 }
240 
DeviceTypeString(const DeviceType & device_type)241 string DeviceTypeString(const DeviceType& device_type) {
242   return device_type.type();
243 }
244 
DataTypeSliceString(const DataTypeSlice types)245 string DataTypeSliceString(const DataTypeSlice types) {
246   string out;
247   for (auto it = types.begin(); it != types.end(); ++it) {
248     strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
249                        DataTypeString(*it));
250   }
251   return out;
252 }
253 
DataTypeAlwaysOnHost(DataType dt)254 bool DataTypeAlwaysOnHost(DataType dt) {
255   // Includes DT_STRING and DT_RESOURCE.
256   switch (dt) {
257     case DT_STRING:
258     case DT_STRING_REF:
259     case DT_RESOURCE:
260       return true;
261     default:
262       return false;
263   }
264 }
265 
DataTypeSize(DataType dt)266 int DataTypeSize(DataType dt) {
267 #define CASE(T)                  \
268   case DataTypeToEnum<T>::value: \
269     return sizeof(T);
270   switch (dt) {
271     TF_CALL_POD_TYPES(CASE);
272     TF_CALL_QUANTIZED_TYPES(CASE);
273     // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since
274     // they are not supported widely, but are explicitly listed here for
275     // bitcast.
276     TF_CALL_qint16(CASE);
277     TF_CALL_quint16(CASE);
278 
279     default:
280       return 0;
281   }
282 #undef CASE
283 }
284 
285 // Define DataTypeToEnum<T>::value.
286 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
287   constexpr DataType DataTypeToEnum<TYPE>::value;
288 
289 DEFINE_DATATYPETOENUM_VALUE(float);
290 DEFINE_DATATYPETOENUM_VALUE(double);
291 DEFINE_DATATYPETOENUM_VALUE(int32);
292 DEFINE_DATATYPETOENUM_VALUE(uint32);
293 DEFINE_DATATYPETOENUM_VALUE(uint16);
294 DEFINE_DATATYPETOENUM_VALUE(uint8);
295 DEFINE_DATATYPETOENUM_VALUE(int16);
296 DEFINE_DATATYPETOENUM_VALUE(int8);
297 DEFINE_DATATYPETOENUM_VALUE(tstring);
298 DEFINE_DATATYPETOENUM_VALUE(complex64);
299 DEFINE_DATATYPETOENUM_VALUE(complex128);
300 DEFINE_DATATYPETOENUM_VALUE(int64_t);
301 DEFINE_DATATYPETOENUM_VALUE(uint64);
302 DEFINE_DATATYPETOENUM_VALUE(bool);
303 DEFINE_DATATYPETOENUM_VALUE(qint8);
304 DEFINE_DATATYPETOENUM_VALUE(quint8);
305 DEFINE_DATATYPETOENUM_VALUE(qint16);
306 DEFINE_DATATYPETOENUM_VALUE(quint16);
307 DEFINE_DATATYPETOENUM_VALUE(qint32);
308 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
309 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
310 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
311 DEFINE_DATATYPETOENUM_VALUE(Variant);
312 #undef DEFINE_DATATYPETOENUM_VALUE
313 
314 }  // namespace tensorflow
315