• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
18 
19 #include <map>
20 #include <set>
21 #include <string>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 // Disable clang-format to prevent 'FixedPoint' header from being included
25 // before 'Tensor' header on which it depends.
26 // clang-format off
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
28 // clang-format on
29 #include "tensorflow/core/framework/bfloat16.h"
30 #include "tensorflow/core/framework/full_type.pb.h"
31 #include "tensorflow/core/framework/numeric_types.h"
32 #include "tensorflow/core/framework/resource_handle.h"
33 #include "tensorflow/core/framework/types.pb.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/array_slice.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace tensorflow {
41 
42 class Variant;
43 
44 // MemoryType is used to describe whether input or output Tensors of
45 // an OpKernel should reside in "Host memory" (e.g., CPU memory) or
46 // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
47 // devices).
48 enum MemoryType {
49   DEVICE_MEMORY = 0,
50   HOST_MEMORY = 1,
51 };
52 
53 // A DeviceType is just a string, but we wrap it up in a class to give
54 // some type checking as we're passing these around
55 class DeviceType {
56  public:
DeviceType(const char * type)57   DeviceType(const char* type)  // NOLINT(runtime/explicit)
58       : type_(type) {}
59 
DeviceType(StringPiece type)60   explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
61 
type()62   const char* type() const { return type_.c_str(); }
type_string()63   const std::string& type_string() const { return type_; }
64 
65   bool operator<(const DeviceType& other) const;
66   bool operator==(const DeviceType& other) const;
67   bool operator!=(const DeviceType& other) const { return !(*this == other); }
68 
69  private:
70   std::string type_;
71 };
72 std::ostream& operator<<(std::ostream& os, const DeviceType& d);
73 
74 // Convenient constants that can be passed to a DeviceType constructor
75 TF_EXPORT extern const char* const DEVICE_DEFAULT;     // "DEFAULT"
76 TF_EXPORT extern const char* const DEVICE_CPU;         // "CPU"
77 TF_EXPORT extern const char* const DEVICE_GPU;         // "GPU"
78 TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM;  // "TPU_SYSTEM"
79 
80 template <typename Device>
81 struct DeviceName {};
82 
83 template <>
84 struct DeviceName<Eigen::ThreadPoolDevice> {
85   static const std::string value;
86 };
87 
88 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
89     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
90 template <>
91 struct DeviceName<Eigen::GpuDevice> {
92   static const std::string value;
93 };
94 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
95 
96 
97 typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
98 typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
99 
100 typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
101 typedef gtl::ArraySlice<DataType> DataTypeSlice;
102 
103 typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
104 typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
105     PrioritizedDeviceTypeVector;
106 
107 // Convert the enums to strings for errors:
108 std::string DataTypeString(DataType dtype);
109 std::string DeviceTypeString(const DeviceType& device_type);
110 std::string DataTypeSliceString(const DataTypeSlice dtypes);
111 inline std::string DataTypeVectorString(const DataTypeVector& dtypes) {
112   return DataTypeSliceString(dtypes);
113 }
114 
115 // DataTypeSet represents a set of DataType values as a simple and efficient
116 // bit mask.  Note that DataTypeSet cannot represent all DataType values; it
117 // cannot represent any of the DT_*_REF values.
118 class DataTypeSet {
119  private:
120   const uint32 mask_;
121 
122   static constexpr uint32 kNumBits = 32;
123 
124  public:
125   constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {}
126   explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {}
127 
128   constexpr bool Contains(DataType dt) const {
129     return (static_cast<uint32>(dt) < kNumBits) &&
130            ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u;
131   }
132 
133   class Iterator {
134     const DataTypeSet& set_;
135     uint32 pos_;
136 
137    public:
138     Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) {
139       DCHECK_LE(pos, kNumBits);
140     }
141     DataType operator*() const { return static_cast<DataType>(pos_); }
142     Iterator& operator++() {
143       ++pos_;
144       DCHECK_LE(pos_, kNumBits);
145       if (pos_ < kNumBits) {
146         uint32 remaining_mask = set_.mask_ >> pos_;
147         if (remaining_mask != 0u) {
148           pos_ += ctz_uint32(remaining_mask);
149         }
150       }
151       DCHECK_LE(pos_, kNumBits);
152       return *this;
153     }
154     bool operator==(const Iterator& other) const { return pos_ == other.pos_; }
155     bool operator!=(const Iterator& other) const { return !(*this == other); }
156     size_t operator-(const Iterator& other) const {
157       return this->pos_ - other.pos_;
158     }
159   };
160 
161   static uint32 ctz_uint32(uint32 x) {
162     DCHECK_NE(x, 0u);
163 #ifdef __GNUC__
164     return __builtin_ctz(x);
165 #else
166     uint32 n = 0u;
167     while ((x & 1u) == 0u) {
168       x >>= 1;
169       ++n;
170     }
171     return n;
172 #endif
173   }
174 
175   static uint32 clz_uint32(uint32 x) {
176     DCHECK_NE(x, 0u);
177 #ifdef __GNUC__
178     return __builtin_clz(x);
179 #else
180     uint32 n = 0u;
181     while ((x >> (kNumBits - 1u)) == 0u) {
182       x <<= 1;
183       ++n;
184     }
185     return n;
186 #endif
187   }
188 
189   Iterator begin() const {
190     // The begin position is the index of the first bit set to 1 in the entire
191     // bit mask. If there are no bits set to 1, then the index is 0.
192     if (mask_ != 0) {
193       return Iterator(*this, ctz_uint32(mask_));
194     }
195     // The set is empty.
196     return Iterator(*this, 0);
197   }
198 
199   Iterator end() const {
200     // The end position is the index of the highest bit that is set, plus 1.
201     // If there are no bits set to 1, then the index is 0.
202     if (mask_ != 0) {
203       return Iterator(*this, kNumBits - clz_uint32(mask_));
204     }
205     // The set is empty.
206     return Iterator(*this, 0);
207   }
208 
209   size_t size() const {
210 #if defined(__GNUC__)
211     return __builtin_popcount(mask_);
212 #else
213     size_t n = 0;
214     uint32 x = mask_;
215     while (x > 0) {
216       n += x & 1u;
217       x >>= 1;
218     }
219     return n;
220 #endif
221   }
222 
223   constexpr DataTypeSet operator|(const DataTypeSet& other) const {
224     return DataTypeSet(mask_ | other.mask_);
225   }
226 };
227 
228 // If "sp" names a valid type, store it in "*dt" and return true.  Otherwise,
229 // return false.
230 bool DataTypeFromString(StringPiece sp, DataType* dt);
231 
232 constexpr inline DataTypeSet ToSet(DataType dt) {
233   return DataTypeSet(1u << static_cast<uint32>(dt));
234 }
235 
236 // DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
237 enum { kDataTypeRefOffset = 100 };
238 inline bool IsRefType(DataType dtype) {
239   return dtype > static_cast<DataType>(kDataTypeRefOffset);
240 }
241 inline DataType MakeRefType(DataType dtype) {
242   DCHECK(!IsRefType(dtype));
243   return static_cast<DataType>(dtype + kDataTypeRefOffset);
244 }
245 inline DataType RemoveRefType(DataType dtype) {
246   DCHECK(IsRefType(dtype));
247   return static_cast<DataType>(dtype - kDataTypeRefOffset);
248 }
249 inline DataType BaseType(DataType dtype) {
250   return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
251 }
252 
253 // Returns true if the actual type is the same as or ref of the expected type.
254 inline bool TypesCompatible(DataType expected, DataType actual) {
255   return expected == actual || expected == BaseType(actual);
256 }
257 
258 // Does not include _ref types.
259 constexpr DataTypeSet kAllTypes =
260     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) |
261     ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) |
262     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
263     ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) |
264     ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) |
265     ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) |
266     ToSet(DT_BFLOAT16);
267 inline const DataTypeSet& AllTypes() { return kAllTypes; }
268 
269 #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
270 
271 // Types that support '<' and '>'.
272 constexpr DataTypeSet kRealNumberTypes =
273     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
274     ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) |
275     ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
276 inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
277 
278 // Return the list of all numeric types.
279 // Includes complex and quantized types.
280 // NOTE: On Android, we only include the float and int32 types for now.
281 const DataTypeSet kNumberTypes =
282     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) |
283     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
284     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) |
285     ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) |
286     ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
287 inline const DataTypeSet& NumberTypes() { return kNumberTypes; }
288 
289 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
290                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
291                                         ToSet(DT_QINT32);
292 inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; }
293 
294 // Types that support '<' and '>', including quantized types.
295 const DataTypeSet kRealAndQuantizedTypes =
296     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
297     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
298     ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
299     ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16);
300 inline const DataTypeSet& RealAndQuantizedTypes() {
301   return kRealAndQuantizedTypes;
302 }
303 
304 #elif defined(__ANDROID_TYPES_FULL__)
305 
306 constexpr DataTypeSet kRealNumberTypes =
307     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF);
308 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
309 
310 constexpr DataTypeSet kNumberTypes =
311     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
312     ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF);
313 inline DataTypeSet NumberTypes() { return kNumberTypes; }
314 
315 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
316                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
317                                         ToSet(DT_QINT32);
318 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
319 
320 constexpr DataTypeSet kRealAndQuantizedTypes =
321     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
322     ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
323     ToSet(DT_HALF);
324 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
325 
326 #else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
327 
328 constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32);
329 inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }
330 
331 constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) |
332                                      ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
333                                      ToSet(DT_QINT32);
334 inline DataTypeSet NumberTypes() { return kNumberTypes; }
335 
336 constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
337                                         ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
338                                         ToSet(DT_QINT32);
339 inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }
340 
341 constexpr DataTypeSet kRealAndQuantizedTypes =
342     ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
343     ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32);
344 inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }
345 
346 #endif  // defined(IS_MOBILE_PLATFORM)
347 
348 // Validates type T for whether it is a supported DataType.
349 template <class T>
350 struct IsValidDataType;
351 
352 // DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
353 // constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
354 template <class T>
355 struct DataTypeToEnum {
356   static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
357 };  // Specializations below
358 
359 // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
360 // EnumToDataType<DT_FLOAT>::Type is float.
361 template <DataType VALUE>
362 struct EnumToDataType {};  // Specializations below
363 
364 // Template specialization for both DataTypeToEnum and EnumToDataType.
365 #define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
366   template <>                                           \
367   struct DataTypeToEnum<TYPE> {                         \
368     static DataType v() { return ENUM; }                \
369     static DataType ref() { return MakeRefType(ENUM); } \
370     static constexpr DataType value = ENUM;             \
371   };                                                    \
372   template <>                                           \
373   struct IsValidDataType<TYPE> {                        \
374     static constexpr bool value = true;                 \
375   };                                                    \
376   template <>                                           \
377   struct EnumToDataType<ENUM> {                         \
378     typedef TYPE Type;                                  \
379   }
380 
381 MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
382 MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
383 MATCH_TYPE_AND_ENUM(int32, DT_INT32);
384 MATCH_TYPE_AND_ENUM(uint32, DT_UINT32);
385 MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
386 MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
387 MATCH_TYPE_AND_ENUM(int16, DT_INT16);
388 MATCH_TYPE_AND_ENUM(int8, DT_INT8);
389 MATCH_TYPE_AND_ENUM(tstring, DT_STRING);
390 MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
391 MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
392 MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
393 MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
394 MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
395 MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
396 MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
397 MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
398 MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
399 MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
400 MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
401 MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);
402 
403 template <>
404 struct DataTypeToEnum<long> {
405   static DataType v() { return value; }
406   static DataType ref() { return MakeRefType(value); }
407   static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64;
408 };
409 template <>
410 struct IsValidDataType<long> {
411   static constexpr bool value = true;
412 };
413 template <>
414 struct EnumToDataType<DT_INT64> {
415   typedef tensorflow::int64 Type;
416 };
417 
418 template <>
419 struct DataTypeToEnum<unsigned long> {
420   static DataType v() { return value; }
421   static DataType ref() { return MakeRefType(value); }
422   static constexpr DataType value =
423       sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64;
424 };
425 template <>
426 struct IsValidDataType<unsigned long> {
427   static constexpr bool value = true;
428 };
429 template <>
430 struct EnumToDataType<DT_UINT64> {
431   typedef tensorflow::uint64 Type;
432 };
433 
434 template <>
435 struct DataTypeToEnum<long long> {
436   static DataType v() { return DT_INT64; }
437   static DataType ref() { return MakeRefType(DT_INT64); }
438   static constexpr DataType value = DT_INT64;
439 };
440 template <>
441 struct IsValidDataType<long long> {
442   static constexpr bool value = true;
443 };
444 
445 template <>
446 struct DataTypeToEnum<unsigned long long> {
447   static DataType v() { return DT_UINT64; }
448   static DataType ref() { return MakeRefType(DT_UINT64); }
449   static constexpr DataType value = DT_UINT64;
450 };
451 template <>
452 struct IsValidDataType<unsigned long long> {
453   static constexpr bool value = true;
454 };
455 
456 #undef MATCH_TYPE_AND_ENUM
457 
458 // All types not specialized are marked invalid.
459 template <class T>
460 struct IsValidDataType {
461   static constexpr bool value = false;
462 };
463 
464 // Extra validity checking; not part of public API.
465 static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
466 static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
467 
468 // TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
469 // is_simple<T> in tensor.cc (and possible choose a more general name?)
470 constexpr DataTypeSet kDataTypesCanUseMemcpy =
471     ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) |
472     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
473     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
474     ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
475     ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
476     ToSet(DT_BFLOAT16) | ToSet(DT_HALF);
477 inline bool DataTypeCanUseMemcpy(DataType dt) {
478   return kDataTypesCanUseMemcpy.Contains(dt);
479 }
480 
481 // Returns true iff 'dt' is a real, non-quantized floating point type.
482 constexpr DataTypeSet kDataTypeIsFloating =
483     ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
484 inline bool DataTypeIsFloating(DataType dt) {
485   return kDataTypeIsFloating.Contains(dt);
486 }
487 
488 // Returns true iff 'dt' is a complex type.
489 constexpr DataTypeSet kDataTypeIsComplex =
490     ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128);
491 inline bool DataTypeIsComplex(DataType dt) {
492   return kDataTypeIsComplex.Contains(dt);
493 }
494 
495 inline bool DataTypeIsQuantized(DataType dt) {
496   return kQuantizedTypes.Contains(dt);
497 }
498 
499 // Is the dtype nonquantized integral?
500 constexpr DataTypeSet kDataTypeIsInteger =
501     ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) |
502     ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64);
503 inline bool DataTypeIsInteger(DataType dt) {
504   return kDataTypeIsInteger.Contains(dt);
505 }
506 
507 // Is the dtype a signed integral type?
508 constexpr DataTypeSet kDataTypeIsSigned =
509     ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
510 inline bool DataTypeIsSigned(DataType dt) {
511   return kDataTypeIsSigned.Contains(dt);
512 }
513 
514 // Is the dtype an unsigned integral type?
515 constexpr DataTypeSet kDataTypeIsUnsigned =
516     ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
517 inline bool DataTypeIsUnsigned(DataType dt) {
518   return kDataTypeIsUnsigned.Contains(dt);
519 }
520 
521 // Returns a 0 on failure
522 int DataTypeSize(DataType dt);
523 
524 // Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32,
525 // DEVICE_MEMORY otherwise.
526 MemoryType MTypeFromDType(const DataType dtype);
527 
528 // Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise.
529 // The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for
530 // GPUs, we would like to keep int operations on host for performance concerns.
531 // But for TPUs (and other devices), int operations are placed on device.
532 MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype);
533 
534 // Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
535 // For DT_RESOURCE, the handle always sits on host (even if the underlying
536 // object has device-allocated resources).
537 bool DataTypeAlwaysOnHost(DataType dt);
538 
539 // FullType implementation.
540 
541 // Reference container for a type definition. These values are usually interned.
542 // These containers admit a notion of ordering for efficient access. The
543 // ordering has no semantic otherwise.
544 struct TypeRef {
545   std::shared_ptr<FullTypeDef> full_type;
546 
547   bool operator==(const TypeRef& other) const {
548     // TODO(mdan): This should be more efficient.
549     return full_type->SerializeAsString() ==
550            other.full_type->SerializeAsString();
551   }
552   bool operator<(const TypeRef& other) const {
553     return full_type->SerializeAsString() <
554            other.full_type->SerializeAsString();
555   }
556 };
557 
558 struct TypeHasher {
559   std::size_t operator()(const TypeRef& k) const {
560     return std::hash<std::string>()(k.full_type->SerializeAsString());
561   }
562 };
563 
564 // Maps a legacy DType proto enum to an equivalent FullType Tensor.
565 void map_dtype_to_tensor(const DataType& dtype, FullTypeDef* t);
566 
567 }  // namespace tensorflow
568 
569 #endif  // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
570