/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

// Formats hiprandStatus_t to output prettified values into a log stream.
std::ostream& operator<<(std::ostream& in, const hiprandStatus_t& status) {
#define OSTREAM_HIPRAND_STATUS(__name) \
  case HIPRAND_STATUS_##__name:        \
    in << "HIPRAND_STATUS_" #__name;   \
    return in;

  switch (status) {
    OSTREAM_HIPRAND_STATUS(SUCCESS)
    OSTREAM_HIPRAND_STATUS(VERSION_MISMATCH)
    OSTREAM_HIPRAND_STATUS(NOT_INITIALIZED)
    OSTREAM_HIPRAND_STATUS(ALLOCATION_FAILED)
    OSTREAM_HIPRAND_STATUS(TYPE_ERROR)
    OSTREAM_HIPRAND_STATUS(OUT_OF_RANGE)
    OSTREAM_HIPRAND_STATUS(LENGTH_NOT_MULTIPLE)
    OSTREAM_HIPRAND_STATUS(LAUNCH_FAILURE)
    OSTREAM_HIPRAND_STATUS(PREEXISTING_FAILURE)
    OSTREAM_HIPRAND_STATUS(INITIALIZATION_FAILED)
    OSTREAM_HIPRAND_STATUS(ARCH_MISMATCH)
    OSTREAM_HIPRAND_STATUS(INTERNAL_ERROR)
    default:
      in << "hiprandStatus_t(" << static_cast<int>(status) << ")";
      return in;
  }
}

namespace stream_executor {
namespace gpu {

PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);

namespace wrap {

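// The STREAM_EXECUTOR_HIPRAND_WRAP macro below generates one callable shim
// per hipRAND entry point. In Google-internal builds (PLATFORM_GOOGLE) the
// shim calls hipRAND directly; otherwise it lazily loads the rocrand DSO via
// internal::CachedDsoLoader and resolves the symbol on first use. Every call
// activates the owning GpuExecutor's context before dispatching to hipRAND.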
#ifdef PLATFORM_GOOGLE

#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                         \
  struct WrapperShim__##__name {                                     \
    template <typename... Args>                                      \
    hiprandStatus_t operator()(GpuExecutor* parent, Args... args) {  \
      gpu::ScopedActivateExecutorContext sac{parent};                \
      return ::__name(args...);                                      \
    }                                                                \
  } __name;

#else

#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                              \
  struct DynLoadShim__##__name {                                          \
    static const char* kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void* GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocrandDsoHandle();          \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void* f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocrand DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    hiprandStatus_t operator()(GpuExecutor* parent, Args... args) {       \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

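// hipRAND entry points wrapped for use by GpuRng below.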
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);

}  // namespace wrap

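// GpuRng owns a single hiprandGenerator_t: it is created lazily by Init()
// and destroyed with the object. Generation calls below hold mu_ and bind
// the generator to the caller's stream before generating.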
GpuRng::GpuRng(GpuExecutor* parent) : parent_(parent), rng_(nullptr) {}

GpuRng::~GpuRng() {
  if (rng_ != nullptr) {
    wrap::hiprandDestroyGenerator(parent_, rng_);
  }
}

bool GpuRng::Init() {
  absl::MutexLock lock{&mu_};
  CHECK(rng_ == nullptr);

  hiprandStatus_t ret =
      wrap::hiprandCreateGenerator(parent_, &rng_, HIPRAND_RNG_PSEUDO_DEFAULT);
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to create random number generator: " << ret;
    return false;
  }

  CHECK(rng_ != nullptr);
  return true;
}

bool GpuRng::SetStream(Stream* stream) {
  hiprandStatus_t ret =
      wrap::hiprandSetStream(parent_, rng_, AsGpuStreamValue(stream));
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set stream for random generation: " << ret;
    return false;
  }

  return true;
}

// Returns true if std::complex stores its contents as two consecutive
// elements. Tests int, float and double, as the last two are independent
// specializations.
constexpr bool ComplexIsConsecutiveFloats() {
  return sizeof(std::complex<int>) == 8 && sizeof(std::complex<float>) == 8 &&
         sizeof(std::complex<double>) == 16;
}

template <typename T>
bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
  absl::MutexLock lock{&mu_};
  static_assert(ComplexIsConsecutiveFloats(),
                "std::complex values are not stored as consecutive values");

  if (!SetStream(stream)) {
    return false;
  }

  // std::complex<T> is currently implemented as two consecutive T variables.
  uint64 element_count = v->ElementCount();
  if (std::is_same<T, std::complex<float>>::value ||
      std::is_same<T, std::complex<double>>::value) {
    element_count *= 2;
  }

  hiprandStatus_t ret;
  if (std::is_same<T, float>::value ||
      std::is_same<T, std::complex<float>>::value) {
    ret = wrap::hiprandGenerateUniform(
        parent_, rng_, reinterpret_cast<float*>(GpuMemoryMutable(v)),
        element_count);
  } else {
    ret = wrap::hiprandGenerateUniformDouble(
        parent_, rng_, reinterpret_cast<double*>(GpuMemoryMutable(v)),
        element_count);
  }
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount()
               << " " << TypeString<T>() << "s at " << v->opaque() << ": "
               << ret;
    return false;
  }

  return true;
}

bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<float>* v) {
  return DoPopulateRandUniformInternal(stream, v);
}

bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<double>* v) {
  return DoPopulateRandUniformInternal(stream, v);
}

bool GpuRng::DoPopulateRandUniform(Stream* stream,
                                   DeviceMemory<std::complex<float>>* v) {
  return DoPopulateRandUniformInternal(stream, v);
}

bool GpuRng::DoPopulateRandUniform(Stream* stream,
                                   DeviceMemory<std::complex<double>>* v) {
  return DoPopulateRandUniformInternal(stream, v);
}

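// Gaussian generation shares one code path for float and double via the
// FuncT parameter, which receives either wrap::hiprandGenerateNormal or
// wrap::hiprandGenerateNormalDouble from the overloads below.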
template <typename ElemT, typename FuncT>
bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
                                            ElemT stddev,
                                            DeviceMemory<ElemT>* v,
                                            FuncT func) {
  absl::MutexLock lock{&mu_};

  if (!SetStream(stream)) {
    return false;
  }

  uint64 element_count = v->ElementCount();
  hiprandStatus_t ret =
      func(parent_, rng_, GpuMemoryMutable(v), element_count, mean, stddev);

  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to do gaussian generation of " << v->ElementCount()
               << " floats at " << v->opaque() << ": " << ret;
    return false;
  }

  return true;
}

bool GpuRng::DoPopulateRandGaussian(Stream* stream, float mean, float stddev,
                                    DeviceMemory<float>* v) {
  return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
                                        wrap::hiprandGenerateNormal);
}

bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
                                    DeviceMemory<double>* v) {
  return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
                                        wrap::hiprandGenerateNormalDouble);
}

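// Reseeds the generator. Only the first 8 bytes of `seed` are consumed (as a
// uint64 hipRAND seed), and the generator offset is reset to 0 so that the
// next generation call starts from the beginning of the new sequence.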
bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
  absl::MutexLock lock{&mu_};
  CHECK(rng_ != nullptr);

  if (!CheckSeed(seed, seed_bytes)) {
    return false;
  }

  if (!SetStream(stream)) {
    return false;
  }

  // Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above)
  // (which itself requires 16 for API consistency with host RNG fallbacks).
  hiprandStatus_t ret = wrap::hiprandSetPseudoRandomGeneratorSeed(
      parent_, rng_, *(reinterpret_cast<const uint64*>(seed)));
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set rng seed: " << ret;
    return false;
  }

  ret = wrap::hiprandSetGeneratorOffset(parent_, rng_, 0);
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to reset rng position: " << ret;
    return false;
  }
  return true;
}

}  // namespace gpu

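// Registers the rocRAND-backed RNG factory with the PluginRegistry (unless a
// factory is already present) and makes it the default RNG plugin for the
// ROCm platform. Invoked once at module-initialization time via the
// REGISTER_MODULE_INITIALIZER hook at the bottom of this file.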
void initialize_rocrand() {
  auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
      rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);

  if (!rocRandAlreadyRegistered) {
    port::Status status =
        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
            rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
            [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
              gpu::GpuExecutor* rocm_executor =
                  dynamic_cast<gpu::GpuExecutor*>(parent);
              if (rocm_executor == nullptr) {
                LOG(ERROR)
                    << "Attempting to initialize an instance of the hipRAND "
                    << "support library with a non-ROCM StreamExecutor";
                return nullptr;
              }

              gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
              if (!rng->Init()) {
                // Note: Init() will log a more specific error.
                delete rng;
                return nullptr;
              }
              return rng;
            });

    if (!status.ok()) {
      LOG(ERROR) << "Unable to register rocRAND factory: "
                 << status.error_message();
    }

    PluginRegistry::Instance()->SetDefaultFactory(
        rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
  }
}

}  // namespace stream_executor

REGISTER_MODULE_INITIALIZER(register_rocrand,
                            { stream_executor::initialize_rocrand(); });