• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
18 
19 #include <map>
20 #include <set>
21 #include <string>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 // Disable clang-format to prevent 'FixedPoint' header from being included
25 // before 'Tensor' header on which it depends.
26 // clang-format off
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
28 // clang-format on
29 #include "tensorflow/core/framework/bfloat16.h"
30 #include "tensorflow/core/framework/numeric_types.h"
31 #include "tensorflow/core/framework/resource_handle.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/lib/gtl/inlined_vector.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 
// Forward declaration only: Variant is used below solely as a template
// argument (MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT)), so a declaration
// suffices here.
class Variant;

// Where the input or output Tensors of an OpKernel should reside:
//   DEVICE_MEMORY - "Device" memory: CPU memory for CPU devices, GPU memory
//                   for GPU devices.
//   HOST_MEMORY   - "Host memory" (e.g., CPU memory).
enum MemoryType {
  DEVICE_MEMORY = 0,  // Memory of the device the kernel runs on.
  HOST_MEMORY = 1,    // Host memory.
};
51 
52 // A DeviceType is just a string, but we wrap it up in a class to give
53 // some type checking as we're passing these around
54 class DeviceType {
55  public:
DeviceType(const char * type)56   DeviceType(const char* type)  // NOLINT(runtime/explicit)
57       : type_(type) {}
58 
DeviceType(StringPiece type)59   explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
60 
type()61   const char* type() const { return type_.c_str(); }
type_string()62   const std::string& type_string() const { return type_; }
63 
64   bool operator<(const DeviceType& other) const;
65   bool operator==(const DeviceType& other) const;
66   bool operator!=(const DeviceType& other) const { return !(*this == other); }
67 
68  private:
69   std::string type_;
70 };
71 std::ostream& operator<<(std::ostream& os, const DeviceType& d);
72 
73 // Convenient constants that can be passed to a DeviceType constructor
74 TF_EXPORT extern const char* const DEVICE_DEFAULT;     // "DEFAULT"
75 TF_EXPORT extern const char* const DEVICE_CPU;         // "CPU"
76 TF_EXPORT extern const char* const DEVICE_GPU;         // "GPU"
77 TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM;  // "TPU_SYSTEM"
78 
79 template <typename Device>
80 struct DeviceName {};
81 
82 template <>
83 struct DeviceName<Eigen::ThreadPoolDevice> {
84   static const std::string value;
85 };
86 
87 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
88     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
89 template <>
90 struct DeviceName<Eigen::GpuDevice> {
91   static const std::string value;
92 };
93 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94 
95 
96 typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
97 typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
98 
99 typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
100 typedef gtl::ArraySlice<DataType> DataTypeSlice;
101 
102 typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
103 typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
104     PrioritizedDeviceTypeVector;
105 
106 // Convert the enums to strings for errors:
107 std::string DataTypeString(DataType dtype);
108 std::string DeviceTypeString(const DeviceType& device_type);
109 std::string DataTypeSliceString(const DataTypeSlice dtypes);
110 inline std::string DataTypeVectorString(const DataTypeVector& dtypes) {
111   return DataTypeSliceString(dtypes);
112 }
113 
114 // DataTypeSet represents a set of DataType values as a simple and efficient
115 // bit mask.  Note that DataTypeSet cannot represent all DataType values; it
116 // cannot represent any of the DT_*_REF values.
117 class DataTypeSet {
118  private:
119   const uint32 mask_;
120 
121   static constexpr uint32 kNumBits = 32;
122 
123  public:
124   constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {}
125   explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {}
126 
127   constexpr bool Contains(DataType dt) const {
128     return (static_cast<uint32>(dt) < kNumBits) &&
129            ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u;
130   }
131 
132   class Iterator {
133     const DataTypeSet& set_;
134     uint32 pos_;
135 
136    public:
137     Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) {
138       DCHECK_LE(pos, kNumBits);
139     }
140     DataType operator*() const { return static_cast<DataType>(pos_); }
141     Iterator& operator++() {
142       ++pos_;
143       DCHECK_LE(pos_, kNumBits);
144       if (pos_ < kNumBits) {
145         uint32 remaining_mask = set_.mask_ >> pos_;
146         if (remaining_mask != 0u) {
147           pos_ += ctz_uint32(remaining_mask);
148         }
149       }
150       DCHECK_LE(pos_, kNumBits);
151       return *this;
152     }
153     bool operator==(const Iterator& other) const { return pos_ == other.pos_; }
154     bool operator!=(const Iterator& other) const { return !(*this == other); }
155     size_t operator-(const Iterator& other) const {
156       return this->pos_ - other.pos_;
157     }
158   };
159 
160   static uint32 ctz_uint32(uint32 x) {
161     DCHECK_NE(x, 0u);
162 #ifdef __GNUC__
163     return __builtin_ctz(x);
164 #else
165     uint32 n = 0u;
166     while ((x & 1u) == 0u) {
167       x >>= 1;
168       ++n;
169     }
170     return n;
171 #endif
172   }
173 
174   static uint32 clz_uint32(uint32 x) {
175     DCHECK_NE(x, 0u);
176 #ifdef __GNUC__
177     return __builtin_clz(x);
178 #else
179     uint32 n = 0u;
180     while ((x >> (kNumBits - 1u)) == 0u) {
181       x <<= 1;
182       ++n;
183     }
184     return n;
185 #endif
186   }
187 
188   Iterator begin() const {
189     // The begin position is the index of the first bit set to 1 in the entire
190     // bit mask. If there are no bits set to 1, then the index is 0.
191     if (mask_ != 0) {
192       return Iterator(*this, ctz_uint32(mask_));
193     }
194     // The set is empty.
195     return Iterator(*this, 0);
196   }
197 
198   Iterator end() const {
199     // The end position is the index of the highest bit that is set, plus 1.
200     // If there are no bits set to 1, then the index is 0.
201     if (mask_ != 0) {
202       return Iterator(*this, kNumBits - clz_uint32(mask_));
203     }
204     // The set is empty.
205     return Iterator(*this, 0);
206   }
207 
208   size_t size() const {
209 #if defined(__GNUC__)
210     return __builtin_popcount(mask_);
211 #else
212     size_t n = 0;
213     uint32 x = mask_;
214     while (x > 0) {
215       n += x & 1u;
216       x >>= 1;
217     }
218     return n;
219 #endif
220   }
221 
222   constexpr DataTypeSet operator|(const DataTypeSet& other) const {
223     return DataTypeSet(mask_ | other.mask_);
224   }
225 };
226 
227 // If "sp" names a valid type, store it in "*dt" and return true.  Otherwise,
228 // return false.
229 bool DataTypeFromString(StringPiece sp, DataType* dt);
230 
231 constexpr inline DataTypeSet ToSet(DataType dt) {
232   return DataTypeSet(1u << static_cast<uint32>(dt));
233 }
234 
235 // DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
236 enum { kDataTypeRefOffset = 100 };
237 inline bool IsRefType(DataType dtype) {
238   return dtype > static_cast<DataType>(kDataTypeRefOffset);
239 }
240 inline DataType MakeRefType(DataType dtype) {
241   DCHECK(!IsRefType(dtype));
242   return static_cast<DataType>(dtype + kDataTypeRefOffset);
243 }
244 inline DataType RemoveRefType(DataType dtype) {
245   DCHECK(IsRefType(dtype));
246   return static_cast<DataType>(dtype - kDataTypeRefOffset);
247 }
248 inline DataType BaseType(DataType dtype) {
249   return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
250 }
251 
252 // Returns true if the actual type is the same as or ref of the expected type.
253 inline bool TypesCompatible(DataType expected, DataType actual) {
254   return expected == actual || expected == BaseType(actual);
255 }
256 
257 // Does not include _ref types.
258 constexpr DataTypeSet kAllTypes =
259     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) |
260     ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) |
261     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
262     ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) |
263     ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) |
264     ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) |
265     ToSet(DT_BFLOAT16);
266 inline const DataTypeSet& AllTypes() { return kAllTypes; }
267 
268 #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
269 
270 // Types that support '<' and '>'.
271 constexpr DataTypeSet kRealNumberTypes =
272     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
273     ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) |
274     ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
275 inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
276 
277 // Return the list of all numeric types.
278 // Includes complex and quantized types.
279 // NOTE: On Android, we only include the float and int32 types for now.
280 const DataTypeSet kNumberTypes =
281     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) |
282     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
283     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) |
284     ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) |
285     ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
286 inline const DataTypeSet& NumberTypes() { return kNumberTypes; }
287 
288 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
289                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
290                                         ToSet(DT_QINT32);
291 inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; }
292 
293 // Types that support '<' and '>', including quantized types.
294 const DataTypeSet kRealAndQuantizedTypes =
295     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
296     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
297     ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
298     ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16);
299 inline const DataTypeSet& RealAndQuantizedTypes() {
300   return kRealAndQuantizedTypes;
301 }
302 
303 #elif defined(__ANDROID_TYPES_FULL__)
304 
305 constexpr DataTypeSet kRealNumberTypes =
306     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF);
307 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
308 
309 constexpr DataTypeSet kNumberTypes =
310     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
311     ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF);
312 inline DataTypeSet NumberTypes() { return kNumberTypes; }
313 
314 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
315                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
316                                         ToSet(DT_QINT32);
317 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
318 
319 constexpr DataTypeSet kRealAndQuantizedTypes =
320     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
321     ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
322     ToSet(DT_HALF);
323 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
324 
325 #else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
326 
327 constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32);
328 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
329 
330 constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) |
331                                      ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
332                                      ToSet(DT_QINT32);
333 inline DataTypeSet NumberTypes() { return kNumberTypes; }
334 
335 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
336                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
337                                         ToSet(DT_QINT32);
338 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
339 
340 constexpr DataTypeSet kRealAndQuantizedTypes =
341     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
342     ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32);
343 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
344 
345 #endif  // defined(IS_MOBILE_PLATFORM)
346 
347 // Validates type T for whether it is a supported DataType.
348 template <class T>
349 struct IsValidDataType;
350 
351 // DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
352 // constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
353 template <class T>
354 struct DataTypeToEnum {
355   static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
356 };  // Specializations below
357 
358 // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
359 // EnumToDataType<DT_FLOAT>::Type is float.
360 template <DataType VALUE>
361 struct EnumToDataType {};  // Specializations below
362 
363 // Template specialization for both DataTypeToEnum and EnumToDataType.
364 #define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
365   template <>                                           \
366   struct DataTypeToEnum<TYPE> {                         \
367     static DataType v() { return ENUM; }                \
368     static DataType ref() { return MakeRefType(ENUM); } \
369     static constexpr DataType value = ENUM;             \
370   };                                                    \
371   template <>                                           \
372   struct IsValidDataType<TYPE> {                        \
373     static constexpr bool value = true;                 \
374   };                                                    \
375   template <>                                           \
376   struct EnumToDataType<ENUM> {                         \
377     typedef TYPE Type;                                  \
378   }
379 
380 MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
381 MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
382 MATCH_TYPE_AND_ENUM(int32, DT_INT32);
383 MATCH_TYPE_AND_ENUM(uint32, DT_UINT32);
384 MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
385 MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
386 MATCH_TYPE_AND_ENUM(int16, DT_INT16);
387 MATCH_TYPE_AND_ENUM(int8, DT_INT8);
388 MATCH_TYPE_AND_ENUM(tstring, DT_STRING);
389 MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
390 MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
391 MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
392 MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
393 MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
394 MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
395 MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
396 MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
397 MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
398 MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
399 MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
400 MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);
401 
402 template <>
403 struct DataTypeToEnum<long> {
404   static DataType v() { return value; }
405   static DataType ref() { return MakeRefType(value); }
406   static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64;
407 };
408 template <>
409 struct IsValidDataType<long> {
410   static constexpr bool value = true;
411 };
412 template <>
413 struct EnumToDataType<DT_INT64> {
414   typedef tensorflow::int64 Type;
415 };
416 
417 template <>
418 struct DataTypeToEnum<unsigned long> {
419   static DataType v() { return value; }
420   static DataType ref() { return MakeRefType(value); }
421   static constexpr DataType value =
422       sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64;
423 };
424 template <>
425 struct IsValidDataType<unsigned long> {
426   static constexpr bool value = true;
427 };
428 template <>
429 struct EnumToDataType<DT_UINT64> {
430   typedef tensorflow::uint64 Type;
431 };
432 
433 template <>
434 struct DataTypeToEnum<long long> {
435   static DataType v() { return DT_INT64; }
436   static DataType ref() { return MakeRefType(DT_INT64); }
437   static constexpr DataType value = DT_INT64;
438 };
439 template <>
440 struct IsValidDataType<long long> {
441   static constexpr bool value = true;
442 };
443 
444 template <>
445 struct DataTypeToEnum<unsigned long long> {
446   static DataType v() { return DT_UINT64; }
447   static DataType ref() { return MakeRefType(DT_UINT64); }
448   static constexpr DataType value = DT_UINT64;
449 };
450 template <>
451 struct IsValidDataType<unsigned long long> {
452   static constexpr bool value = true;
453 };
454 
455 #undef MATCH_TYPE_AND_ENUM
456 
457 // All types not specialized are marked invalid.
458 template <class T>
459 struct IsValidDataType {
460   static constexpr bool value = false;
461 };
462 
463 // Extra validity checking; not part of public API.
464 static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
465 static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
466 
467 // TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
468 // is_simple<T> in tensor.cc (and possible choose a more general name?)
469 constexpr DataTypeSet kDataTypesCanUseMemcpy =
470     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) |
471     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
472     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
473     ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
474     ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
475     ToSet(DT_BFLOAT16) | ToSet(DT_HALF);
476 inline bool DataTypeCanUseMemcpy(DataType dt) {
477   return kDataTypesCanUseMemcpy.Contains(dt);
478 }
479 
480 // Returns true iff 'dt' is a real, non-quantized floating point type.
481 constexpr DataTypeSet kDataTypeIsFloating =
482     ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
483 inline bool DataTypeIsFloating(DataType dt) {
484   return kDataTypeIsFloating.Contains(dt);
485 }
486 
487 // Returns true iff 'dt' is a complex type.
488 constexpr DataTypeSet kDataTypeIsComplex =
489     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128);
490 inline bool DataTypeIsComplex(DataType dt) {
491   return kDataTypeIsComplex.Contains(dt);
492 }
493 
494 inline bool DataTypeIsQuantized(DataType dt) {
495   return kQuantizedTypes.Contains(dt);
496 }
497 
498 // Is the dtype nonquantized integral?
499 constexpr DataTypeSet kDataTypeIsInteger =
500     ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) |
501     ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64);
502 inline bool DataTypeIsInteger(DataType dt) {
503   return kDataTypeIsInteger.Contains(dt);
504 }
505 
506 // Is the dtype a signed integral type?
507 constexpr DataTypeSet kDataTypeIsSigned =
508     ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
509 inline bool DataTypeIsSigned(DataType dt) {
510   return kDataTypeIsSigned.Contains(dt);
511 }
512 
513 // Is the dtype an unsigned integral type?
514 constexpr DataTypeSet kDataTypeIsUnsigned =
515     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
516 inline bool DataTypeIsUnsigned(DataType dt) {
517   return kDataTypeIsUnsigned.Contains(dt);
518 }
519 
520 // Returns a 0 on failure
521 int DataTypeSize(DataType dt);
522 
523 // Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32,
524 // DEVICE_MEMORY otherwise.
525 MemoryType MTypeFromDType(const DataType dtype);
526 
527 // Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise.
528 // The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for
529 // GPUs, we would like to keep int operations on host for performance concerns.
530 // But for TPUs (and other devices), int operations are placed on device.
531 MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype);
532 
533 // Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
534 // For DT_RESOURCE, the handle always sits on host (even if the underlying
535 // object has device-allocated resources).
536 bool DataTypeAlwaysOnHost(DataType dt);
537 
538 }  // namespace tensorflow
539 
540 #endif  // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
541