/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
#define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_

#include <map>
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Disable clang-format to prevent 'FixedPoint' header from being included
// before 'Tensor' header on which it depends.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Variant;

// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" memory (CPU memory for CPU devices, GPU memory for GPU
// devices).
enum MemoryType {
  DEVICE_MEMORY = 0,
  HOST_MEMORY = 1,
};

// A DeviceType is just a string, but we wrap it up in a class to give
// some type checking as we're passing these around.
class DeviceType {
 public:
  DeviceType(const char* type)  // NOLINT(runtime/explicit)
      : type_(type) {}

  explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}

  const char* type() const { return type_.c_str(); }
  const std::string& type_string() const { return type_; }

  bool operator<(const DeviceType& other) const;
  bool operator==(const DeviceType& other) const;
  bool operator!=(const DeviceType& other) const { return !(*this == other); }

 private:
  std::string type_;
};
std::ostream& operator<<(std::ostream& os, const DeviceType& d);

// Convenient constants that can be passed to a DeviceType constructor.
TF_EXPORT extern const char* const DEVICE_DEFAULT;     // "DEFAULT"
TF_EXPORT extern const char* const DEVICE_CPU;         // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU;         // "GPU"
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM;  // "TPU_SYSTEM"

template <typename Device>
struct DeviceName {};

template <>
struct DeviceName<Eigen::ThreadPoolDevice> {
  static const std::string value;
};

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
template <>
struct DeviceName<Eigen::GpuDevice> {
  static const std::string value;
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;

typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
typedef gtl::ArraySlice<DataType> DataTypeSlice;

typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
    PrioritizedDeviceTypeVector;

// Convert the enums to strings for errors:
std::string DataTypeString(DataType dtype);
std::string DeviceTypeString(const DeviceType& device_type);
std::string DataTypeSliceString(const DataTypeSlice dtypes);
inline std::string DataTypeVectorString(const DataTypeVector& dtypes) {
  return DataTypeSliceString(dtypes);
}
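
// Example usage of DeviceType and the string helpers above.  This is a
// minimal, illustrative sketch only; the local names `device` and `dtypes`
// are hypothetical and not part of this header.
//
//   DeviceType device(DEVICE_CPU);
//   CHECK(device == DeviceType("CPU"));
//   LOG(INFO) << "Placing op on " << device.type_string();
//
//   DataTypeVector dtypes = {DT_FLOAT, DT_INT32};
//   LOG(INFO) << "Expected inputs: " << DataTypeVectorString(dtypes);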
// DataTypeSet represents a set of DataType values as a simple and efficient
// bit mask.  Note that DataTypeSet cannot represent all DataType values; it
// cannot represent any of the DT_*_REF values.
class DataTypeSet {
 private:
  const uint32 mask_;

  static constexpr uint32 kNumBits = 32;

 public:
  constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {}
  explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {}

  constexpr bool Contains(DataType dt) const {
    return (static_cast<uint32>(dt) < kNumBits) &&
           ((mask_ >> static_cast<uint32>(dt)) & 1u) != 0u;
  }

  class Iterator {
    const DataTypeSet& set_;
    uint32 pos_;

   public:
    Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) {
      DCHECK_LE(pos, kNumBits);
    }
    DataType operator*() const { return static_cast<DataType>(pos_); }
    Iterator& operator++() {
      ++pos_;
      DCHECK_LE(pos_, kNumBits);
      if (pos_ < kNumBits) {
        uint32 remaining_mask = set_.mask_ >> pos_;
        if (remaining_mask != 0u) {
          pos_ += ctz_uint32(remaining_mask);
        }
      }
      DCHECK_LE(pos_, kNumBits);
      return *this;
    }
    bool operator==(const Iterator& other) const { return pos_ == other.pos_; }
    bool operator!=(const Iterator& other) const { return !(*this == other); }
    size_t operator-(const Iterator& other) const {
      return this->pos_ - other.pos_;
    }
  };

  static uint32 ctz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_ctz(x);
#else
    uint32 n = 0u;
    while ((x & 1u) == 0u) {
      x >>= 1;
      ++n;
    }
    return n;
#endif
  }

  static uint32 clz_uint32(uint32 x) {
    DCHECK_NE(x, 0u);
#ifdef __GNUC__
    return __builtin_clz(x);
#else
    uint32 n = 0u;
    while ((x >> (kNumBits - 1u)) == 0u) {
      x <<= 1;
      ++n;
    }
    return n;
#endif
  }

  Iterator begin() const {
    // The begin position is the index of the first bit set to 1 in the entire
    // bit mask.  If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, ctz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  Iterator end() const {
    // The end position is the index of the highest bit that is set, plus 1.
    // If there are no bits set to 1, then the index is 0.
    if (mask_ != 0) {
      return Iterator(*this, kNumBits - clz_uint32(mask_));
    }
    // The set is empty.
    return Iterator(*this, 0);
  }

  size_t size() const {
#if defined(__GNUC__)
    return __builtin_popcount(mask_);
#else
    size_t n = 0;
    uint32 x = mask_;
    while (x > 0) {
      n += x & 1u;
      x >>= 1;
    }
    return n;
#endif
  }

  constexpr DataTypeSet operator|(const DataTypeSet& other) const {
    return DataTypeSet(mask_ | other.mask_);
  }
};

// If "sp" names a valid type, store it in "*dt" and return true.  Otherwise,
// return false.
bool DataTypeFromString(StringPiece sp, DataType* dt);

constexpr inline DataTypeSet ToSet(DataType dt) {
  return DataTypeSet(1u << static_cast<uint32>(dt));
}
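
// Example usage of DataTypeSet and ToSet().  A minimal, illustrative sketch;
// the local names `set` and `dt` are hypothetical.
//
//   constexpr DataTypeSet set = ToSet(DT_FLOAT) | ToSet(DT_INT32);
//   static_assert(set.Contains(DT_FLOAT), "DT_FLOAT should be in the set");
//   // Iteration visits each member type in increasing order of enum value.
//   for (const DataType dt : set) {
//     LOG(INFO) << DataTypeString(dt);
//   }
//   CHECK_EQ(set.size(), 2);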
// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
enum { kDataTypeRefOffset = 100 };
inline bool IsRefType(DataType dtype) {
  return dtype > static_cast<DataType>(kDataTypeRefOffset);
}
inline DataType MakeRefType(DataType dtype) {
  DCHECK(!IsRefType(dtype));
  return static_cast<DataType>(dtype + kDataTypeRefOffset);
}
inline DataType RemoveRefType(DataType dtype) {
  DCHECK(IsRefType(dtype));
  return static_cast<DataType>(dtype - kDataTypeRefOffset);
}
inline DataType BaseType(DataType dtype) {
  return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
}

// Returns true if the actual type is the same as or ref of the expected type.
inline bool TypesCompatible(DataType expected, DataType actual) {
  return expected == actual || expected == BaseType(actual);
}

// Does not include _ref types.
constexpr DataTypeSet kAllTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) |
    ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
    ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) |
    ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) |
    ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) |
    ToSet(DT_BFLOAT16);
inline const DataTypeSet& AllTypes() { return kAllTypes; }

#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)

// Types that support '<' and '>'.
constexpr DataTypeSet kRealNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
    ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) |
    ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
inline const DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

// Return the list of all numeric types.
// Includes complex and quantized types.
// NOTE: On Android, we only include the float and int32 types for now.
const DataTypeSet kNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_UINT32) |
    ToSet(DT_UINT64) | ToSet(DT_BFLOAT16);
inline const DataTypeSet& NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; }
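
// Example usage of the ref-type helpers and the type groupings above.  A
// minimal, illustrative sketch; the local name `ref` is hypothetical.
//
//   DataType ref = MakeRefType(DT_FLOAT);   // DT_FLOAT_REF
//   CHECK(IsRefType(ref));
//   CHECK_EQ(BaseType(ref), DT_FLOAT);
//   CHECK(TypesCompatible(DT_FLOAT, ref));  // A ref matches its base type.
//
//   CHECK(NumberTypes().Contains(DT_INT32));
//   CHECK(!NumberTypes().Contains(DT_STRING));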
// Types that support '<' and '>', including quantized types.
const DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
    ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16);
inline const DataTypeSet& RealAndQuantizedTypes() {
  return kRealAndQuantizedTypes;
}

#elif defined(__ANDROID_TYPES_FULL__)

constexpr DataTypeSet kRealNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF);
inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

constexpr DataTypeSet kNumberTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF);
inline DataTypeSet NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }

constexpr DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) |
    ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
    ToSet(DT_HALF);
inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }

#else  // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)

constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32);
inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; }

constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) |
                                     ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                     ToSet(DT_QINT32);
inline DataTypeSet NumberTypes() { return kNumberTypes; }

constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
                                        ToSet(DT_QINT16) | ToSet(DT_QUINT16) |
                                        ToSet(DT_QINT32);
inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; }

constexpr DataTypeSet kRealAndQuantizedTypes =
    ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32);
inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; }

#endif  // defined(IS_MOBILE_PLATFORM)

// Validates type T for whether it is a supported DataType.
template <class T>
struct IsValidDataType;

// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
template <class T>
struct DataTypeToEnum {
  static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
};  // Specializations below

// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE>
struct EnumToDataType {};  // Specializations below
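
// Example usage of the DataTypeToEnum/EnumToDataType traits (their
// specializations follow below).  A minimal, illustrative sketch; the local
// name `ref` is hypothetical and std::is_same requires <type_traits>.
//
//   static_assert(DataTypeToEnum<float>::value == DT_FLOAT, "");
//   static_assert(std::is_same<EnumToDataType<DT_INT32>::Type, int32>::value,
//                 "");
//   DataType ref = DataTypeToEnum<float>::ref();  // DT_FLOAT_REF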
// Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
  template <>                                           \
  struct DataTypeToEnum<TYPE> {                         \
    static DataType v() { return ENUM; }                \
    static DataType ref() { return MakeRefType(ENUM); } \
    static constexpr DataType value = ENUM;             \
  };                                                    \
  template <>                                           \
  struct IsValidDataType<TYPE> {                        \
    static constexpr bool value = true;                 \
  };                                                    \
  template <>                                           \
  struct EnumToDataType<ENUM> {                         \
    typedef TYPE Type;                                  \
  }

MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
MATCH_TYPE_AND_ENUM(int32, DT_INT32);
MATCH_TYPE_AND_ENUM(uint32, DT_UINT32);
MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
MATCH_TYPE_AND_ENUM(int16, DT_INT16);
MATCH_TYPE_AND_ENUM(int8, DT_INT8);
MATCH_TYPE_AND_ENUM(tstring, DT_STRING);
MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);

template <>
struct DataTypeToEnum<long> {
  static DataType v() { return value; }
  static DataType ref() { return MakeRefType(value); }
  static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64;
};
template <>
struct IsValidDataType<long> {
  static constexpr bool value = true;
};
template <>
struct EnumToDataType<DT_INT64> {
  typedef tensorflow::int64 Type;
};

template <>
struct DataTypeToEnum<unsigned long> {
  static DataType v() { return value; }
  static DataType ref() { return MakeRefType(value); }
  static constexpr DataType value =
      sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64;
};
template <>
struct IsValidDataType<unsigned long> {
  static constexpr bool value = true;
};
template <>
struct EnumToDataType<DT_UINT64> {
  typedef tensorflow::uint64 Type;
};

template <>
struct DataTypeToEnum<long long> {
  static DataType v() { return DT_INT64; }
  static DataType ref() { return MakeRefType(DT_INT64); }
  static constexpr DataType value = DT_INT64;
};
template <>
struct IsValidDataType<long long> {
  static constexpr bool value = true;
};

template <>
struct DataTypeToEnum<unsigned long long> {
  static DataType v() { return DT_UINT64; }
  static DataType ref() { return MakeRefType(DT_UINT64); }
  static constexpr DataType value = DT_UINT64;
};
template <>
struct IsValidDataType<unsigned long long> {
  static constexpr bool value = true;
};

#undef MATCH_TYPE_AND_ENUM

// All types not specialized are marked invalid.
template <class T>
struct IsValidDataType {
  static constexpr bool value = false;
};

// Extra validity checking; not part of public API.
static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
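
// Example of the typical pattern for using these traits in generic code.  A
// minimal, illustrative sketch; the function name `CheckPayloadType` and its
// parameter are hypothetical, not part of this header.
//
//   template <typename T>
//   bool CheckPayloadType(DataType dtype) {
//     // Accept either T's canonical DataType or its corresponding ref type.
//     return BaseType(dtype) == DataTypeToEnum<T>::value;
//   }
//
//   CheckPayloadType<float>(DT_FLOAT);      // true
//   CheckPayloadType<float>(DT_FLOAT_REF);  // true
//   CheckPayloadType<float>(DT_INT32);      // false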
// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
// is_simple<T> in tensor.cc (and possibly choose a more general name?)
constexpr DataTypeSet kDataTypesCanUseMemcpy =
    ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) |
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) |
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) |
    ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) |
    ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) |
    ToSet(DT_BFLOAT16) | ToSet(DT_HALF);
inline bool DataTypeCanUseMemcpy(DataType dt) {
  return kDataTypesCanUseMemcpy.Contains(dt);
}

// Returns true iff 'dt' is a real, non-quantized floating point type.
constexpr DataTypeSet kDataTypeIsFloating =
    ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE);
inline bool DataTypeIsFloating(DataType dt) {
  return kDataTypeIsFloating.Contains(dt);
}

// Returns true iff 'dt' is a complex type.
constexpr DataTypeSet kDataTypeIsComplex =
    ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128);
inline bool DataTypeIsComplex(DataType dt) {
  return kDataTypeIsComplex.Contains(dt);
}

inline bool DataTypeIsQuantized(DataType dt) {
  return kQuantizedTypes.Contains(dt);
}

// Is the dtype a non-quantized integral type?
constexpr DataTypeSet kDataTypeIsInteger =
    ToSet(DT_INT8) | ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_UINT16) |
    ToSet(DT_INT32) | ToSet(DT_UINT32) | ToSet(DT_INT64) | ToSet(DT_UINT64);
inline bool DataTypeIsInteger(DataType dt) {
  return kDataTypeIsInteger.Contains(dt);
}

// Is the dtype a signed integral type?
constexpr DataTypeSet kDataTypeIsSigned =
    ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
inline bool DataTypeIsSigned(DataType dt) {
  return kDataTypeIsSigned.Contains(dt);
}

// Is the dtype an unsigned integral type?
constexpr DataTypeSet kDataTypeIsUnsigned =
    ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
inline bool DataTypeIsUnsigned(DataType dt) {
  return kDataTypeIsUnsigned.Contains(dt);
}

// Returns the size in bytes of the given DataType, or 0 on failure.
int DataTypeSize(DataType dt);

// Returns HOST_MEMORY if `dtype` always lives on host or is DT_INT32,
// DEVICE_MEMORY otherwise.
MemoryType MTypeFromDType(const DataType dtype);

// Returns HOST_MEMORY if `dtype` always lives on host, DEVICE_MEMORY
// otherwise.  The reason we have both MTypeFromDType() and
// MTypeFromDTypeIntsOnDevice(): for GPUs, we would like to keep int
// operations on the host for performance reasons, but for TPUs (and other
// devices), int operations are placed on the device.
MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype);

// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
// For DT_RESOURCE, the handle always sits on host (even if the underlying
// object has device-allocated resources).
bool DataTypeAlwaysOnHost(DataType dt);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
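
// Example usage of the predicates declared above.  A minimal, illustrative
// sketch; the local name `dt` is hypothetical.
//
//   DataType dt = DT_BFLOAT16;
//   CHECK(DataTypeIsFloating(dt));
//   CHECK(DataTypeCanUseMemcpy(dt));   // Fixed-size, trivially copyable.
//   CHECK(!DataTypeIsQuantized(dt));
//   CHECK_EQ(DataTypeSize(dt), 2);     // Size in bytes; 0 if unknown.
//   CHECK(!DataTypeAlwaysOnHost(dt));  // Strings/resources stay on host.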