• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstring>

#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
30 
31 // Formats hiprandStatus_t to output prettified values into a log stream.
operator <<(std::ostream & in,const hiprandStatus_t & status)32 std::ostream& operator<<(std::ostream& in, const hiprandStatus_t& status) {
33 #define OSTREAM_HIPRAND_STATUS(__name) \
34   case HIPRAND_STATUS_##__name:        \
35     in << "HIPRAND_STATUS_" #__name;   \
36     return in;
37 
38   switch (status) {
39     OSTREAM_HIPRAND_STATUS(SUCCESS)
40     OSTREAM_HIPRAND_STATUS(VERSION_MISMATCH)
41     OSTREAM_HIPRAND_STATUS(NOT_INITIALIZED)
42     OSTREAM_HIPRAND_STATUS(ALLOCATION_FAILED)
43     OSTREAM_HIPRAND_STATUS(TYPE_ERROR)
44     OSTREAM_HIPRAND_STATUS(OUT_OF_RANGE)
45     OSTREAM_HIPRAND_STATUS(LENGTH_NOT_MULTIPLE)
46     OSTREAM_HIPRAND_STATUS(LAUNCH_FAILURE)
47     OSTREAM_HIPRAND_STATUS(PREEXISTING_FAILURE)
48     OSTREAM_HIPRAND_STATUS(INITIALIZATION_FAILED)
49     OSTREAM_HIPRAND_STATUS(ARCH_MISMATCH)
50     OSTREAM_HIPRAND_STATUS(INTERNAL_ERROR)
51     default:
52       in << "hiprandStatus_t(" << static_cast<int>(status) << ")";
53       return in;
54   }
55 }
56 
57 namespace stream_executor {
58 namespace gpu {
59 
// Defines `kGpuRandPlugin`, the plugin id that initialize_rocrand() in this
// file uses when registering the rocRAND factory and making it the default.
60 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
61 
62 namespace wrap {
63 
// STREAM_EXECUTOR_HIPRAND_WRAP(__name) defines a callable object named
// `__name` (inside namespace wrap) that forwards its arguments to the hipRAND
// entry point of the same name. Both variants take the GpuExecutor as an
// extra leading argument and hold a ScopedActivateExecutorContext for it
// while the hipRAND call runs. Both return hiprandStatus_t.
#ifdef PLATFORM_GOOGLE

// This variant calls the global `::__name` symbol directly.
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                        \
  struct WrapperShim__##__name {                                    \
    template <typename... Args>                                     \
    hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent};               \
      return ::__name(args...);                                     \
    }                                                               \
  } __name;

#else

// This variant looks the symbol up by name in the handle returned by
// CachedDsoLoader::GetRocrandDsoHandle() on first use and caches the function
// pointer in a function-local static. A failed handle lookup dies in
// ValueOrDie(); a missing symbol CHECK-fails with the loader's error message.
// `f` is null-initialized so it never holds an indeterminate value if the
// lookup fails without writing to it.
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                              \
  struct DynLoadShim__##__name {                                          \
    static const char* kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void* GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocrandDsoHandle();          \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void* f = nullptr;                                                  \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocrand DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    hiprandStatus_t operator()(GpuExecutor* parent, Args... args) {       \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char* DynLoadShim__##__name::kName = #__name;

#endif
106 
107 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
108 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
109 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
110 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
111 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
112 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
113 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
114 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
115 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);
116 
117 }  // namespace wrap
118 
GpuRng(GpuExecutor * parent)119 GpuRng::GpuRng(GpuExecutor* parent) : parent_(parent), rng_(nullptr) {}
120 
~GpuRng()121 GpuRng::~GpuRng() {
122   if (rng_ != nullptr) {
123     wrap::hiprandDestroyGenerator(parent_, rng_);
124   }
125 }
126 
Init()127 bool GpuRng::Init() {
128   absl::MutexLock lock{&mu_};
129   CHECK(rng_ == nullptr);
130 
131   hiprandStatus_t ret =
132       wrap::hiprandCreateGenerator(parent_, &rng_, HIPRAND_RNG_PSEUDO_DEFAULT);
133   if (ret != HIPRAND_STATUS_SUCCESS) {
134     LOG(ERROR) << "failed to create random number generator: " << ret;
135     return false;
136   }
137 
138   CHECK(rng_ != nullptr);
139   return true;
140 }
141 
SetStream(Stream * stream)142 bool GpuRng::SetStream(Stream* stream) {
143   hiprandStatus_t ret =
144       wrap::hiprandSetStream(parent_, rng_, AsGpuStreamValue(stream));
145   if (ret != HIPRAND_STATUS_SUCCESS) {
146     LOG(ERROR) << "failed to set stream for random generation: " << ret;
147     return false;
148   }
149 
150   return true;
151 }
152 
// Returns true if std::complex<T> occupies exactly the space of two T values,
// i.e. its real and imaginary parts can be treated as two consecutive
// elements with no padding. Checks the int, float and double instantiations,
// as the last two are independent specializations.
//
// Sizes are compared against 2 * sizeof(T) rather than literal byte counts so
// the check does not depend on the platform's widths for int/float/double.
constexpr bool ComplexIsConsecutiveFloats() {
  return sizeof(std::complex<int>) == 2 * sizeof(int) &&
         sizeof(std::complex<float>) == 2 * sizeof(float) &&
         sizeof(std::complex<double>) == 2 * sizeof(double);
}
160 
161 template <typename T>
DoPopulateRandUniformInternal(Stream * stream,DeviceMemory<T> * v)162 bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
163   absl::MutexLock lock{&mu_};
164   static_assert(ComplexIsConsecutiveFloats(),
165                 "std::complex values are not stored as consecutive values");
166 
167   if (!SetStream(stream)) {
168     return false;
169   }
170 
171   // std::complex<T> is currently implemented as two consecutive T variables.
172   uint64 element_count = v->ElementCount();
173   if (std::is_same<T, std::complex<float>>::value ||
174       std::is_same<T, std::complex<double>>::value) {
175     element_count *= 2;
176   }
177 
178   hiprandStatus_t ret;
179   if (std::is_same<T, float>::value ||
180       std::is_same<T, std::complex<float>>::value) {
181     ret = wrap::hiprandGenerateUniform(
182         parent_, rng_, reinterpret_cast<float*>(GpuMemoryMutable(v)),
183         element_count);
184   } else {
185     ret = wrap::hiprandGenerateUniformDouble(
186         parent_, rng_, reinterpret_cast<double*>(GpuMemoryMutable(v)),
187         element_count);
188   }
189   if (ret != HIPRAND_STATUS_SUCCESS) {
190     LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount()
191                << " " << TypeString<T>() << "s at " << v->opaque() << ": "
192                << ret;
193     return false;
194   }
195 
196   return true;
197 }
198 
DoPopulateRandUniform(Stream * stream,DeviceMemory<float> * v)199 bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<float>* v) {
200   return DoPopulateRandUniformInternal(stream, v);
201 }
202 
DoPopulateRandUniform(Stream * stream,DeviceMemory<double> * v)203 bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<double>* v) {
204   return DoPopulateRandUniformInternal(stream, v);
205 }
206 
DoPopulateRandUniform(Stream * stream,DeviceMemory<std::complex<float>> * v)207 bool GpuRng::DoPopulateRandUniform(Stream* stream,
208                                    DeviceMemory<std::complex<float>>* v) {
209   return DoPopulateRandUniformInternal(stream, v);
210 }
211 
DoPopulateRandUniform(Stream * stream,DeviceMemory<std::complex<double>> * v)212 bool GpuRng::DoPopulateRandUniform(Stream* stream,
213                                    DeviceMemory<std::complex<double>>* v) {
214   return DoPopulateRandUniformInternal(stream, v);
215 }
216 
217 template <typename ElemT, typename FuncT>
DoPopulateRandGaussianInternal(Stream * stream,ElemT mean,ElemT stddev,DeviceMemory<ElemT> * v,FuncT func)218 bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
219                                             ElemT stddev,
220                                             DeviceMemory<ElemT>* v,
221                                             FuncT func) {
222   absl::MutexLock lock{&mu_};
223 
224   if (!SetStream(stream)) {
225     return false;
226   }
227 
228   uint64 element_count = v->ElementCount();
229   hiprandStatus_t ret =
230       func(parent_, rng_, GpuMemoryMutable(v), element_count, mean, stddev);
231 
232   if (ret != HIPRAND_STATUS_SUCCESS) {
233     LOG(ERROR) << "failed to do gaussian generation of " << v->ElementCount()
234                << " floats at " << v->opaque() << ": " << ret;
235     return false;
236   }
237 
238   return true;
239 }
240 
DoPopulateRandGaussian(Stream * stream,float mean,float stddev,DeviceMemory<float> * v)241 bool GpuRng::DoPopulateRandGaussian(Stream* stream, float mean, float stddev,
242                                     DeviceMemory<float>* v) {
243   return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
244                                         wrap::hiprandGenerateNormal);
245 }
246 
DoPopulateRandGaussian(Stream * stream,double mean,double stddev,DeviceMemory<double> * v)247 bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
248                                     DeviceMemory<double>* v) {
249   return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
250                                         wrap::hiprandGenerateNormalDouble);
251 }
252 
SetSeed(Stream * stream,const uint8 * seed,uint64 seed_bytes)253 bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
254   absl::MutexLock lock{&mu_};
255   CHECK(rng_ != nullptr);
256 
257   if (!CheckSeed(seed, seed_bytes)) {
258     return false;
259   }
260 
261   if (!SetStream(stream)) {
262     return false;
263   }
264 
265   // Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above)
266   // (which itself requires 16 for API consistency with host RNG fallbacks).
267   hiprandStatus_t ret = wrap::hiprandSetPseudoRandomGeneratorSeed(
268       parent_, rng_, *(reinterpret_cast<const uint64*>(seed)));
269   if (ret != HIPRAND_STATUS_SUCCESS) {
270     LOG(ERROR) << "failed to set rng seed: " << ret;
271     return false;
272   }
273 
274   ret = wrap::hiprandSetGeneratorOffset(parent_, rng_, 0);
275   if (ret != HIPRAND_STATUS_SUCCESS) {
276     LOG(ERROR) << "failed to reset rng position: " << ret;
277     return false;
278   }
279   return true;
280 }
281 
282 }  // namespace gpu
283 
initialize_rocrand()284 void initialize_rocrand() {
285   auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
286       rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
287 
288   if (!rocRandAlreadyRegistered) {
289     port::Status status =
290         PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
291             rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
292             [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
293               gpu::GpuExecutor* rocm_executor =
294                   dynamic_cast<gpu::GpuExecutor*>(parent);
295               if (rocm_executor == nullptr) {
296                 LOG(ERROR)
297                     << "Attempting to initialize an instance of the hipRAND "
298                     << "support library with a non-ROCM StreamExecutor";
299                 return nullptr;
300               }
301 
302               gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
303               if (!rng->Init()) {
304                 // Note: Init() will log a more specific error.
305                 delete rng;
306                 return nullptr;
307               }
308               return rng;
309             });
310 
311     if (!status.ok()) {
312       LOG(ERROR) << "Unable to register rocRAND factory: "
313                  << status.error_message();
314     }
315 
316     PluginRegistry::Instance()->SetDefaultFactory(
317         rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
318   }
319 }
320 
321 }  // namespace stream_executor
322 
323 REGISTER_MODULE_INITIALIZER(register_rocrand,
324                             { stream_executor::initialize_rocrand(); });
325