• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_RANDOM_H_
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_RANDOM_H_
19 
20 #include <cstdint>
21 #include <cmath>
22 #include <array>
23 #include <limits>
24 #include <random>
25 #include <vector>
26 #include <optional>
27 #include <algorithm>
28 #include <utility>
29 #include "include/common/thread_pool.h"
30 #include "include/common/utils/utils.h"
31 #include "utils/log_adapter.h"
32 
33 namespace mindspore::random {
//
// Fill a buffer with random values, single threaded.
//
// A `Generator` is constructed from `seed` and advanced by `skip` draws through
// its discard() member; a `Distribution` is constructed from `args...`. The
// distribution is then sampled `size` times and every sample is converted to `T`
// and written to consecutive slots of `buf`.
//
// `buf` is checked with MS_EXCEPTION_IF_NULL before anything else happens.
//
template <typename T, typename Generator, typename Distribution, typename... Args>
void GenerateRandoms(std::uint64_t seed, size_t skip, T *buf, size_t size, Args... args) {
  MS_EXCEPTION_IF_NULL(buf);
  Generator engine{seed};
  engine.discard(skip);
  Distribution sampler{args...};
  T *const end = buf + size;
  for (T *out = buf; out != end; ++out) {
    *out = T(sampler(engine));
  }
}
47 
// Compute the number of tasks and the batch size (elements per task).
//
// Returns {task_num, batch_size}:
//  - {1, total_size} when there is no worker thread or when total_size does not
//    exceed 1024 elements (too small to be worth splitting);
//  - otherwise the work is measured in blocks of 4 elements and the blocks are
//    spread as evenly as possible over at most `thread_num` tasks, so the batch
//    size is always a multiple of 4. The last task may receive fewer elements.
static inline std::pair<size_t, size_t> ComputeTaskNumSize(size_t total_size, size_t thread_num) {
  constexpr size_t kMinParallelSize = 1024;
  constexpr size_t kBlockSize = 4;
  const bool run_serially = (thread_num == 0) || (total_size <= kMinParallelSize);
  if (run_serially) {
    return {1, total_size};
  }
  // Integer division rounding up.
  const auto ceil_div = [](size_t numerator, size_t denominator) {
    return (numerator + denominator - 1) / denominator;
  };
  const size_t num_blocks = ceil_div(total_size, kBlockSize);
  if (num_blocks <= thread_num) {
    // Enough threads to give every block its own task.
    return {num_blocks, kBlockSize};
  }
  const size_t blocks_per_task = ceil_div(num_blocks, thread_num);
  const size_t num_tasks = ceil_div(num_blocks, blocks_per_task);
  return {num_tasks, blocks_per_task * kBlockSize};
}
64 
65 //
66 // Parallel generate random numbers into a buffer.
67 //
68 template <typename T, typename Generator, typename Distribution, typename... Args>
GenerateRandomsParallel(std::uint64_t input_seed,T * buf,size_t buf_size,Args...args)69 void GenerateRandomsParallel(std::uint64_t input_seed, T *buf, size_t buf_size, Args... args) {
70   MS_EXCEPTION_IF_NULL(buf);
71 
72   // Calculate number of tasks and batch size.
73   auto &thread_pool = common::ThreadPool::GetInstance();
74   auto [task_num, batch_size] = ComputeTaskNumSize(buf_size, thread_pool.GetSyncRunThreadNum());
75 
76   // Generate random seed if required.
77   std::uint64_t seed = input_seed;
78   if (seed == 0) {
79     std::random_device rd;
80     seed = rd();
81   }
82 
83   if (task_num == 1) {
84     // Use single thread for small data size.
85     GenerateRandoms<T, Generator, Distribution>(seed, 0, buf, buf_size, args...);
86     return;
87   }
88 
89   // Prepare parallel tasks.
90   std::vector<common::Task> tasks;
91   tasks.reserve(task_num);
92   T *task_buf = buf;
93   size_t skip = 0;
94   for (size_t i = 0; i < task_num; ++i) {
95     const auto task_size = ((i == task_num - 1) ? (buf_size - (task_num - 1) * batch_size) : batch_size);
96     (void)tasks.emplace_back([seed, skip, task_buf, task_size, args...]() {
97       GenerateRandoms<T, Generator, Distribution>(seed, skip, task_buf, task_size, args...);
98       return common::SUCCESS;
99     });
100     skip += task_size;
101     task_buf += task_size;
102   }
103   // Parallel execute tasks by thread pool.
104   (void)thread_pool.SyncRun(tasks);
105 }
106 
//
// Philox is a random number generator that is suitable for parallel random number generating.
//
// The state is a 4 x 32-bit counter and a 2 x 32-bit key. One block of four
// 32-bit outputs is obtained by running 10 rounds of compute() over a copy of the
// counter, after which the counter is incremented. Since a block depends only on
// (counter, key), discard() jumps ahead by adding to the counter instead of
// generating and throwing away values.
//
// The class exposes result_type / min() / max() / operator() so that it can be
// used wherever a standard UniformRandomBitGenerator is expected.
//
class Philox {
 public:
  // Type of the values returned by operator().
  using result_type = uint32_t;

  // Use `seed` for both the key and the upper half of the initial counter.
  explicit Philox(uint64_t seed) : Philox(seed, seed) {}

  // `seed` provides the key; `seed2` provides the upper two words of the initial
  // counter (the lower two words start at zero).
  Philox(uint64_t seed, uint64_t seed2)
      : key_{{static_cast<uint32_t>(seed), static_cast<uint32_t>(seed >> kShift32)}},
        counter_{{0, 0, static_cast<uint32_t>(seed2), static_cast<uint32_t>(seed2 >> kShift32)}},
        results_{} {}

  ~Philox() = default;

  // Return the next 32-bit value, generating a new block when the current one is used up.
  uint32_t operator()() {
    if (index_ == kCounterNum) {
      results_ = next();
      index_ = 0;
    }
    return results_[index_++];
  }

  // Advance the generator as if operator() had been called `step` times.
  void discard(uint64_t step) {
    // Values of the current block that have not been handed out yet.
    const uint64_t buffered = kCounterNum - index_;
    if (step <= buffered) {
      index_ += static_cast<size_t>(step);
      return;
    }
    // The remaining draws come from blocks that have not been generated:
    // jump over the whole blocks, then materialize the one we land in (if any).
    const uint64_t beyond = step - buffered;
    skip(beyond / kCounterNum);
    const uint64_t remain = beyond % kCounterNum;
    if (remain > 0) {
      results_ = next();
      index_ = static_cast<size_t>(remain);
    } else {
      index_ = kCounterNum;
    }
  }

  static constexpr uint32_t min() { return 0; }
  static constexpr uint32_t max() { return std::numeric_limits<uint32_t>::max(); }

 private:
  static constexpr int kShift32 = 32;
  static constexpr size_t kCounterNum = 4;
  static constexpr size_t kKeyNum = 2;
  static constexpr size_t kIndex0 = 0;
  static constexpr size_t kIndex1 = 1;
  static constexpr size_t kIndex2 = 2;
  static constexpr size_t kIndex3 = 3;
  // Round multipliers.
  static constexpr uint32_t kMagic0 = 0xD2511F53;
  static constexpr uint32_t kMagic1 = 0xCD9E8D57;
  // Per-round key increments.
  static constexpr uint32_t kKeyStep0 = 0x9E3779B9;
  static constexpr uint32_t kKeyStep1 = 0xBB67AE85;

  using Counter = std::array<uint32_t, kCounterNum>;
  using Key = std::array<uint32_t, kKeyNum>;

  Key key_;
  Counter counter_;
  // Most recently generated block; results_[index_..] are still unused.
  Counter results_;
  // index_ == kCounterNum means "no unused values, generate a block on next draw".
  size_t index_ = kCounterNum;

  // One round: two 32x32->64 multiplications whose halves are mixed with the
  // other counter words and the key. `counter` is updated in place.
  static void compute(uint32_t *counter, const uint32_t *key) {
    const uint64_t t0 = static_cast<uint64_t>(kMagic0) * counter[kIndex0];
    const uint32_t l0 = static_cast<uint32_t>(t0);
    const uint32_t h0 = static_cast<uint32_t>(t0 >> kShift32);
    const uint64_t t1 = static_cast<uint64_t>(kMagic1) * counter[kIndex2];
    const uint32_t l1 = static_cast<uint32_t>(t1);
    const uint32_t h1 = static_cast<uint32_t>(t1 >> kShift32);
    counter[kIndex0] = (h1 ^ counter[kIndex1] ^ key[kIndex0]);
    counter[kIndex1] = l1;
    counter[kIndex2] = (h0 ^ counter[kIndex3] ^ key[kIndex1]);
    counter[kIndex3] = l0;
  }

  // Advance the key between two rounds.
  static void raise_key(uint32_t *key) {
    key[kIndex0] += kKeyStep0;
    key[kIndex1] += kKeyStep1;
  }

  // Generate next 4 random numbers and advance counter.
  Counter next() {
    Counter result = counter_;
    Key key = key_;
    // For performance reason, we do not use loop here,
    // but manually call compute() 10 times.
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    raise_key(key.data());
    compute(result.data(), key.data());
    skip_one();
    return result;
  }

  // Advance counter for one step, propagating the carry through all four words.
  void skip_one() {
    if (++counter_[kIndex0] == 0) {
      if (++counter_[kIndex1] == 0) {
        if (++counter_[kIndex2] == 0) {
          ++counter_[kIndex3];
        }
      }
    }
  }

  // Skip the given number of samples of 4 uint32, i.e. add `count` to the
  // 128-bit counter. The two low words are added as one 64-bit number so that
  // the carry into the upper words is detected by a single wrap-around check.
  void skip(uint64_t count) {
    const uint64_t low = (static_cast<uint64_t>(counter_[kIndex1]) << kShift32) | counter_[kIndex0];
    const uint64_t sum = low + count;
    counter_[kIndex0] = static_cast<uint32_t>(sum);
    counter_[kIndex1] = static_cast<uint32_t>(sum >> kShift32);
    if (sum < count) {
      // The 64-bit addition wrapped: carry into the upper half.
      if (++counter_[kIndex2] == 0) {
        ++counter_[kIndex3];
      }
    }
  }
};
257 
//
// Uniform distribution.
//
// Each sample takes one draw from the generator, rescales it to a fraction in
// [0, 1) using the generator's min()/max(), and maps that fraction linearly onto
// the range that starts at `a` and has width `b - a`.
//
template <typename T>
class UniformDistribution {
 public:
  UniformDistribution(T a, T b) : a_(a), b_(b) {}
  ~UniformDistribution() = default;

  template <typename Generator>
  T operator()(Generator &&g) const {
    const auto lowest = g.min();
    const auto highest = g.max();
    // Number of distinct values the generator can produce.
    const long double span = static_cast<long double>(highest) - static_cast<long double>(lowest) + 1.0L;
    const auto offset = g() - lowest;
    T fraction = static_cast<T>(T(offset) / span);
    // Converting the draw to T may round it up to `span`; keep the fraction below 1.
    if (fraction >= T(1)) {
      fraction = std::nextafter(T(1), T(0));
    }
    const T width = b_ - a_;
    return width * fraction + a_;
  }

 private:
  T a_;
  T b_;
};
283 
//
// Normal distribution.
//
// Samples are produced in pairs with the Box-Muller transform: a call that finds
// no cached value takes two draws from the generator, returns the first sample of
// the pair and caches the second one, which the following call returns without
// touching the generator.
//
// Note that operator() is const but updates the mutable cache (next_, has_next_);
// nothing in this class synchronizes concurrent calls on the same instance.
//
template <typename T>
class NormalDistribution {
 public:
  NormalDistribution(T mean, T sigma) : mean_(mean), sigma_(sigma) {}
  ~NormalDistribution() = default;

  template <typename Generator>
  T operator()(Generator &&g) const {
    if (has_next_) {
      has_next_ = false;
      return next_;
    }
    // Box-Muller transform algorithm:
    // z1 = sqrt(-2 * ln(u1)) * cos(2 * pi * u2)
    // z2 = sqrt(-2 * ln(u1)) * sin(2 * pi * u2)
    constexpr T pi = 3.1415926f;
    // Lower bound for u1/u2; keeps log(u1) finite when a draw maps to 0.
    constexpr T threshold = 1.0e-7f;
    const T u1 = std::max(to_float(g()), threshold);
    const T u2 = std::max(to_float(g()), threshold);
    const T x = std::sqrt(-2.0f * std::log(u1)) * sigma_;
    const T y = 2.0f * pi * u2;
    next_ = mean_ + (x * std::sin(y));
    has_next_ = true;
    return mean_ + (x * std::cos(y));
  }

 private:
  T mean_;
  T sigma_;
  // Second sample of the last generated pair, valid only while has_next_ is true.
  mutable T next_ = 0;
  mutable bool has_next_ = false;

  // Map the low 23 bits of `input` onto [0, 1) in steps of 2^-23; the upper bits
  // are ignored. A 23-bit integer and its product with 2^-23 are both exactly
  // representable as float, so this yields the same value as building the float
  // 1.mantissa from the bits and subtracting 1, without reinterpreting an
  // integer's storage as a float.
  static T to_float(uint32_t input) {
    constexpr uint32_t kMantissaMask = 0x7fffffu;
    constexpr float kMantissaScale = 1.0f / 8388608.0f;  // 2^-23
    return T(static_cast<float>(input & kMantissaMask) * kMantissaScale);
  }
};
330 
331 //
332 // Truncated normal distribution.
333 //
334 template <typename T>
335 class TruncatedNormal {
336  public:
TruncatedNormal(T a,T b,T mean,T sigma)337   TruncatedNormal(T a, T b, T mean, T sigma) : lower_(a), upper_(b), mean_(mean), sigma_(sigma) {
338     if (sigma <= 0) {
339       MS_LOG(EXCEPTION) << "TruncatedNormal: invalid sigma " << sigma << ".";
340     } else {
341       alpha_ = (a - mean) / sigma;
342       beta_ = (b - mean) / sigma;
343     }
344   }
345 
346   ~TruncatedNormal() = default;
347 
348   template <typename Generator>
operator()349   T operator()(Generator &&g) const {
350     // Inverse CDF (Cumulative Distribution Function) method.
351     const T u = std_uniform_(g);
352     const T cdf_a = cdf(alpha_);
353     const T cdf_b = cdf(beta_);
354     const T p = cdf_a + u * (cdf_b - cdf_a);
355     const T x = quantile(p);
356     return mean_ + x * sigma_;
357   }
358 
359  private:
360   UniformDistribution<T> std_uniform_{0.0f, 1.0f};
361   T lower_;
362   T upper_;
363   T mean_;
364   T sigma_;
365   T alpha_;
366   T beta_;
367 
368   static constexpr T kRootTwo = 1.4142135f;
369 
cdf(T x)370   static T cdf(T x) {
371     const T diff = x / kRootTwo;
372     return std::erfc(-diff) / 2.0f;
373   }
374 
quantile(T p)375   static T quantile(T p) {
376     auto z = 2.0f * p;
377     const T x = erfc_inv(z);
378     return -x * kRootTwo;
379   }
380 
erfc_inv(T z)381   static T erfc_inv(T z) {
382     // Keep z in range (0, 2).
383     if (z <= 0) {
384       z = std::nextafterf(0.0f, 2.0f);
385     } else if (z >= 2.0f) {
386       z = std::nextafterf(2.0f, 0.0f);
387     }
388     T p, q, s;
389     if (z > 1.0f) {
390       q = 2.0f - z;
391       p = 1.0f - q;
392       s = -1;
393     } else {
394       p = 1.0f - z;
395       q = z;
396       s = 1;
397     }
398     return s * erf_inv_imp(p, q);
399   }
400 
401   // The algorithm and polynomia constants are borrow from boost.
erf_inv_imp(T p,T q)402   static T erf_inv_imp(T p, T q) {
403     if (p <= 0.5f) {
404       constexpr float Y = 0.0891314744949340820313f;
405       constexpr T P[] = {T(-0.000508781949658280665617), T(-0.00836874819741736770379), T(0.0334806625409744615033),
406                          T(-0.0126926147662974029034),   T(-0.0365637971411762664006),  T(0.0219878681111168899165),
407                          T(0.00822687874676915743155),   T(-0.00538772965071242932965)};
408       constexpr T Q[] = {T(1.0),
409                          T(-0.970005043303290640362),
410                          T(-1.56574558234175846809),
411                          T(1.56221558398423026363),
412                          T(0.662328840472002992063),
413                          T(-0.71228902341542847553),
414                          T(-0.0527396382340099713954),
415                          T(0.0795283687341571680018),
416                          T(-0.00233393759374190016776),
417                          T(0.000886216390456424707504)};
418       T g = p * (p + 10.0f);
419       T r = eval_polynomial(P, p) / eval_polynomial(Q, p);
420       return g * Y + g * r;
421     }
422     if (q >= 0.25f) {
423       constexpr float Y = 2.249481201171875f;
424       constexpr T P[] = {T(-0.202433508355938759655), T(0.105264680699391713268), T(8.37050328343119927838),
425                          T(17.6447298408374015486),   T(-18.8510648058714251895), T(-44.6382324441786960818),
426                          T(17.445385985570866523),    T(21.1294655448340526258),  T(-3.67192254707729348546)};
427       constexpr T Q[] = {T(1.0),
428                          T(6.24264124854247537712),
429                          T(3.9713437953343869095),
430                          T(-28.6608180499800029974),
431                          T(-20.1432634680485188801),
432                          T(48.5609213108739935468),
433                          T(10.8268667355460159008),
434                          T(-22.6436933413139721736),
435                          T(1.72114765761200282724)};
436       T g = std::sqrt(-2.0f * std::log(q));
437       T xs = q - 0.25f;
438       T r = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
439       return g / (Y + r);
440     }
441     // Avoid static check warning for 'function body too long'.
442     return erf_inv_imp2(q);
443   }
444 
erf_inv_imp2(T q)445   static T erf_inv_imp2(T q) {
446     T x = std::sqrt(-std::log(q));
447     if (x < 3.0f) {
448       constexpr float Y = 0.807220458984375f;
449       constexpr T P[] = {T(-0.131102781679951906451),   T(-0.163794047193317060787),   T(0.117030156341995252019),
450                          T(0.387079738972604337464),    T(0.337785538912035898924),    T(0.142869534408157156766),
451                          T(0.0290157910005329060432),   T(0.00214558995388805277169),  T(-0.679465575181126350155e-6),
452                          T(0.285225331782217055858e-7), T(-0.681149956853776992068e-9)};
453       constexpr T Q[] = {T(1.0),
454                          T(3.46625407242567245975),
455                          T(5.38168345707006855425),
456                          T(4.77846592945843778382),
457                          T(2.59301921623620271374),
458                          T(0.848854343457902036425),
459                          T(0.152264338295331783612),
460                          T(0.01105924229346489121)};
461       T xs = x - 1.125f;
462       T R = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
463       return Y * x + R * x;
464     }
465     if (x < 6.0f) {
466       constexpr float Y = 0.93995571136474609375f;
467       constexpr T P[] = {T(-0.0350353787183177984712),  T(-0.00222426529213447927281),  T(0.0185573306514231072324),
468                          T(0.00950804701325919603619),  T(0.00187123492819559223345),   T(0.000157544617424960554631),
469                          T(0.460469890584317994083e-5), T(-0.230404776911882601748e-9), T(0.266339227425782031962e-11)};
470       constexpr T Q[] = {T(1.0),
471                          T(1.3653349817554063097),
472                          T(0.762059164553623404043),
473                          T(0.220091105764131249824),
474                          T(0.0341589143670947727934),
475                          T(0.00263861676657015992959),
476                          T(0.764675292302794483503e-4)};
477       T xs = x - 3.0f;
478       T R = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
479       return Y * x + R * x;
480     }
481     if (x < 18.0f) {
482       constexpr float Y = 0.98362827301025390625f;
483       constexpr T P[] = {T(-0.0167431005076633737133),  T(-0.00112951438745580278863),   T(0.00105628862152492910091),
484                          T(0.000209386317487588078668), T(0.149624783758342370182e-4),   T(0.449696789927706453732e-6),
485                          T(0.462596163522878599135e-8), T(-0.281128735628831791805e-13), T(0.99055709973310326855e-16)};
486       constexpr T Q[] = {T(1.0),
487                          T(0.591429344886417493481),
488                          T(0.138151865749083321638),
489                          T(0.0160746087093676504695),
490                          T(0.000964011807005165528527),
491                          T(0.275335474764726041141e-4),
492                          T(0.282243172016108031869e-6)};
493       T xs = x - 6.0f;
494       T R = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
495       return Y * x + R * x;
496     }
497     if (x < 44.0f) {
498       constexpr float Y = 0.99714565277099609375f;
499       constexpr T P[] = {T(-0.0024978212791898131227),   T(-0.779190719229053954292e-5), T(0.254723037413027451751e-4),
500                          T(0.162397777342510920873e-5),  T(0.396341011304801168516e-7),  T(0.411632831190944208473e-9),
501                          T(0.145596286718675035587e-11), T(-0.116765012397184275695e-17)};
502       constexpr T Q[] = {T(1.0),
503                          T(0.207123112214422517181),
504                          T(0.0169410838120975906478),
505                          T(0.000690538265622684595676),
506                          T(0.145007359818232637924e-4),
507                          T(0.144437756628144157666e-6),
508                          T(0.509761276599778486139e-9)};
509       T xs = x - 18.0f;
510       T R = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
511       return Y * x + R * x;
512     }
513     constexpr float Y = 0.99941349029541015625f;
514     constexpr T P[] = {T(-0.000539042911019078575891), T(-0.28398759004727721098e-6),  T(0.899465114892291446442e-6),
515                        T(0.229345859265920864296e-7),  T(0.225561444863500149219e-9),  T(0.947846627503022684216e-12),
516                        T(0.135880130108924861008e-14), T(-0.348890393399948882918e-21)};
517     constexpr T Q[] = {T(1.0),
518                        T(0.0845746234001899436914),
519                        T(0.00282092984726264681981),
520                        T(0.468292921940894236786e-4),
521                        T(0.399968812193862100054e-6),
522                        T(0.161809290887904476097e-8),
523                        T(0.231558608310259605225e-11)};
524     T xs = x - 44.0f;
525     T R = eval_polynomial(P, xs) / eval_polynomial(Q, xs);
526     return Y * x + R * x;
527   }
528 
529   // We use template function to unrolling polynomial evaluations
530   // at compile time to improve performance.
531   template <size_t N>
eval_polynomial(const T (& arr)[N],T x)532   static T eval_polynomial(const T (&arr)[N], T x) {
533     T sum = arr[N - 1];
534     if constexpr (N > 1) {
535       eval_polynomial_loop<N - kIndex2>(arr, x, &sum);
536     }
537     return sum;
538   }
539 
540   template <size_t Index>
eval_polynomial_loop(const T * arr,T x,T * sum)541   static void eval_polynomial_loop(const T *arr, T x, T *sum) {
542     *sum *= x;
543     *sum += arr[Index];
544     if constexpr (Index > 0) {
545       eval_polynomial_loop<Index - 1>(arr, x, sum);
546     }
547   }
548 };
549 
//
// Constant distribution.
//
// Degenerate distribution that yields the value given at construction on every
// call. The generator argument is accepted for interface compatibility with the
// other distributions and is never used.
//
template <typename T>
class ConstantDistribution {
 public:
  explicit ConstantDistribution(T value) : value_(value) {}
  ~ConstantDistribution() = default;

  template <typename Generator>
  T operator()(Generator && /* unused */) const {
    return value_;
  }

 private:
  // The value returned by every call.
  T value_;
};
567 }  // namespace mindspore::random
568 
569 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_RANDOM_H_
570