1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ 18 19 #include "tensorflow/stream_executor/data_type.h" 20 #include "tensorflow/stream_executor/device_memory.h" 21 #include "tensorflow/stream_executor/platform/logging.h" 22 23 namespace stream_executor { 24 25 // Allows to represent a value that is either a host scalar or a scalar stored 26 // on the GPU device. 27 // See also the specialization for ElemT=void below. 28 template <typename ElemT> 29 class HostOrDeviceScalar { 30 public: 31 // Not marked as explicit because when using this constructor, we usually want 32 // to set this to a compile-time constant. HostOrDeviceScalar(ElemT value)33 HostOrDeviceScalar(ElemT value) : value_(value), is_pointer_(false) {} HostOrDeviceScalar(const DeviceMemory<ElemT> & pointer)34 explicit HostOrDeviceScalar(const DeviceMemory<ElemT>& pointer) 35 : pointer_(pointer), is_pointer_(true) { 36 CHECK_EQ(1, pointer.ElementCount()); 37 } 38 is_pointer()39 bool is_pointer() const { return is_pointer_; } pointer()40 const DeviceMemory<ElemT>& pointer() const { 41 CHECK(is_pointer()); 42 return pointer_; 43 } value()44 const ElemT& value() const { 45 CHECK(!is_pointer()); 46 return value_; 47 } 48 49 private: 50 union { 51 ElemT value_; 52 DeviceMemory<ElemT> pointer_; 53 }; 54 bool is_pointer_; 55 }; 56 57 // Specialization for wrapping a dynamically-typed value (via type erasure). 58 template <> 59 class HostOrDeviceScalar<void> { 60 public: 61 using DataType = dnn::DataType; 62 63 // Constructors not marked as explicit because when using this constructor, we 64 // usually want to set this to a compile-time constant. 65 66 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(float value)67 HostOrDeviceScalar(float value) 68 : float_(value), is_pointer_(false), dtype_(DataType::kFloat) {} 69 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(double value)70 HostOrDeviceScalar(double value) 71 : double_(value), is_pointer_(false), dtype_(DataType::kDouble) {} 72 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(Eigen::half value)73 HostOrDeviceScalar(Eigen::half value) 74 : half_(value), is_pointer_(false), dtype_(DataType::kHalf) {} 75 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(int8 value)76 HostOrDeviceScalar(int8 value) 77 : int8_(value), is_pointer_(false), dtype_(DataType::kInt8) {} 78 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(int32 value)79 HostOrDeviceScalar(int32 value) 80 : int32_(value), is_pointer_(false), dtype_(DataType::kInt32) {} 81 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(std::complex<float> value)82 HostOrDeviceScalar(std::complex<float> value) 83 : complex_float_(value), 84 is_pointer_(false), 85 dtype_(DataType::kComplexFloat) {} 86 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(std::complex<double> value)87 HostOrDeviceScalar(std::complex<double> value) 88 : complex_double_(value), 89 is_pointer_(false), 90 dtype_(DataType::kComplexDouble) {} 91 template <typename T> HostOrDeviceScalar(const DeviceMemory<T> & pointer)92 explicit HostOrDeviceScalar(const DeviceMemory<T>& pointer) 93 : pointer_(pointer), 94 is_pointer_(true), 95 dtype_(dnn::ToDataType<T>::value) { 96 CHECK_EQ(1, pointer.ElementCount()); 97 } 98 // Construct from statically-typed version. 99 template <typename T, typename std::enable_if<!std::is_same<T, void>::value, 100 int>::type = 0> 101 // NOLINTNEXTLINE google-explicit-constructor HostOrDeviceScalar(const HostOrDeviceScalar<T> & other)102 HostOrDeviceScalar(const HostOrDeviceScalar<T>& other) { 103 if (other.is_pointer()) { 104 *this = HostOrDeviceScalar(other.pointer()); 105 } else { 106 *this = HostOrDeviceScalar(other.value()); 107 } 108 } 109 is_pointer()110 bool is_pointer() const { return is_pointer_; } 111 template <typename T> pointer()112 const DeviceMemory<T>& pointer() const { 113 CHECK(is_pointer()); 114 CHECK(dtype_ == dnn::ToDataType<T>::value); 115 return pointer_; 116 } 117 template <typename T> value()118 const T& value() const { 119 CHECK(!is_pointer()); 120 CHECK(dtype_ == dnn::ToDataType<T>::value); 121 return value_impl<T>(); 122 } opaque_pointer()123 const DeviceMemoryBase& opaque_pointer() const { 124 CHECK(is_pointer()); 125 return pointer_; 126 } opaque_value()127 const void* opaque_value() const { 128 CHECK(!is_pointer()); 129 switch (dtype_) { 130 case DataType::kFloat: 131 return &float_; 132 case DataType::kDouble: 133 return &double_; 134 case DataType::kHalf: 135 return &half_; 136 case DataType::kInt8: 137 return &int8_; 138 case DataType::kInt32: 139 return &int32_; 140 case DataType::kComplexFloat: 141 return &complex_float_; 142 case DataType::kComplexDouble: 143 return &complex_double_; 144 default: 145 return nullptr; 146 } 147 } data_type()148 DataType data_type() const { return dtype_; } 149 150 private: 151 template <typename T> 152 const T& value_impl() const; 153 154 union { 155 float float_; 156 double double_; 157 Eigen::half half_; 158 int8 int8_; 159 int32 int32_; 160 std::complex<float> complex_float_; 161 std::complex<double> complex_double_; 162 DeviceMemoryBase pointer_; 163 }; 164 bool is_pointer_; 165 DataType dtype_; 166 }; 167 168 template <> 169 inline const float& HostOrDeviceScalar<void>::value_impl<float>() const { 170 return float_; 171 } 172 173 template <> 174 inline const double& HostOrDeviceScalar<void>::value_impl<double>() const { 175 return double_; 176 } 177 178 template <> 179 inline const Eigen::half& HostOrDeviceScalar<void>::value_impl<Eigen::half>() 180 const { 181 return half_; 182 } 183 184 template <> 185 inline const int8& HostOrDeviceScalar<void>::value_impl<int8>() const { 186 return int8_; 187 } 188 189 template <> 190 inline const int32& HostOrDeviceScalar<void>::value_impl<int32>() const { 191 return int32_; 192 } 193 194 template <> 195 inline const std::complex<float>& 196 HostOrDeviceScalar<void>::value_impl<std::complex<float>>() const { 197 return complex_float_; 198 } 199 200 template <> 201 inline const std::complex<double>& 202 HostOrDeviceScalar<void>::value_impl<std::complex<double>>() const { 203 return complex_double_; 204 } 205 206 } // namespace stream_executor 207 #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_ 208