• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 #include <ATen/test/rng_test.h>
3 #include <ATen/Generator.h>
4 #include <c10/core/GeneratorImpl.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/native/DistributionTemplates.h>
7 #include <ATen/native/cpu/DistributionTemplates.h>
8 #include <torch/library.h>
9 #include <optional>
10 #include <torch/all.h>
11 #include <stdexcept>
12 
13 using namespace at;
14 
15 #ifndef ATEN_CPU_STATIC_DISPATCH
16 namespace {
17 
18 constexpr auto kCustomRNG = DispatchKey::CustomRNGKeyId;
19 
20 struct TestCPUGenerator : public c10::GeneratorImpl {
TestCPUGenerator__anon6d9ca2b10111::TestCPUGenerator21   TestCPUGenerator(uint64_t value) : GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(kCustomRNG)}, value_(value) { }
22   ~TestCPUGenerator() override = default;
random__anon6d9ca2b10111::TestCPUGenerator23   uint32_t random() { return value_; }
random64__anon6d9ca2b10111::TestCPUGenerator24   uint64_t random64() { return value_; }
next_float_normal_sample__anon6d9ca2b10111::TestCPUGenerator25   std::optional<float> next_float_normal_sample() { return next_float_normal_sample_; }
next_double_normal_sample__anon6d9ca2b10111::TestCPUGenerator26   std::optional<double> next_double_normal_sample() { return next_double_normal_sample_; }
set_next_float_normal_sample__anon6d9ca2b10111::TestCPUGenerator27   void set_next_float_normal_sample(std::optional<float> randn) { next_float_normal_sample_ = randn; }
set_next_double_normal_sample__anon6d9ca2b10111::TestCPUGenerator28   void set_next_double_normal_sample(std::optional<double> randn) { next_double_normal_sample_ = randn; }
set_current_seed__anon6d9ca2b10111::TestCPUGenerator29   void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
set_offset__anon6d9ca2b10111::TestCPUGenerator30   void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
get_offset__anon6d9ca2b10111::TestCPUGenerator31   uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
current_seed__anon6d9ca2b10111::TestCPUGenerator32   uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
seed__anon6d9ca2b10111::TestCPUGenerator33   uint64_t seed() override { throw std::runtime_error("not implemented"); }
set_state__anon6d9ca2b10111::TestCPUGenerator34   void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
get_state__anon6d9ca2b10111::TestCPUGenerator35   c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
clone_impl__anon6d9ca2b10111::TestCPUGenerator36   TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
37 
device_type__anon6d9ca2b10111::TestCPUGenerator38   static DeviceType device_type() { return DeviceType::CPU; }
39 
40   uint64_t value_;
41   std::optional<float> next_float_normal_sample_;
42   std::optional<double> next_double_normal_sample_;
43 };
44 
45 // ==================================================== Random ========================================================
46 
random_(Tensor & self,std::optional<Generator> generator)47 Tensor& random_(Tensor& self, std::optional<Generator> generator) {
48   return at::native::templates::random_impl<native::templates::cpu::RandomKernel, TestCPUGenerator>(self, generator);
49 }
50 
random_from_to(Tensor & self,int64_t from,std::optional<int64_t> to,std::optional<Generator> generator)51 Tensor& random_from_to(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> generator) {
52   return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
53 }
54 
random_to(Tensor & self,int64_t to,std::optional<Generator> generator)55 Tensor& random_to(Tensor& self, int64_t to, std::optional<Generator> generator) {
56   return random_from_to(self, 0, to, generator);
57 }
58 
59 // ==================================================== Normal ========================================================
60 
normal_(Tensor & self,double mean,double std,std::optional<Generator> gen)61 Tensor& normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
62   return at::native::templates::normal_impl_<native::templates::cpu::NormalKernel, TestCPUGenerator>(self, mean, std, gen);
63 }
64 
// CustomRNGKeyId kernel for `normal.Tensor_float_out` (tensor mean, scalar std).
// normal_out_impl takes the output tensor as its first argument.
Tensor& normal_Tensor_float_out(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_out_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
}
68 
// CustomRNGKeyId kernel for `normal.float_Tensor_out` (scalar mean, tensor std).
// normal_out_impl takes the output tensor as its first argument.
Tensor& normal_float_Tensor_out(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_out_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
}
72 
// CustomRNGKeyId kernel for `normal.Tensor_Tensor_out` (tensor mean, tensor std).
// normal_out_impl takes the output tensor as its first argument.
Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_out_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
}
76 
// CustomRNGKeyId kernel for functional `normal.Tensor_float`; returns a new tensor.
Tensor normal_Tensor_float(const Tensor& mean, double std, std::optional<Generator> gen) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
}
80 
// CustomRNGKeyId kernel for functional `normal.float_Tensor`; returns a new tensor.
Tensor normal_float_Tensor(double mean, const Tensor& std, std::optional<Generator> gen) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
}
84 
// CustomRNGKeyId kernel for functional `normal.Tensor_Tensor`; returns a new tensor.
Tensor normal_Tensor_Tensor(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  namespace tmpl = at::native::templates;
  return tmpl::normal_impl<tmpl::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
}
88 
89 // ==================================================== Uniform =======================================================
90 
uniform_(Tensor & self,double from,double to,std::optional<Generator> generator)91 Tensor& uniform_(Tensor& self, double from, double to, std::optional<Generator> generator) {
92   return at::native::templates::uniform_impl_<native::templates::cpu::UniformKernel, TestCPUGenerator>(self, from, to, generator);
93 }
94 
95 // ==================================================== Cauchy ========================================================
96 
cauchy_(Tensor & self,double median,double sigma,std::optional<Generator> generator)97 Tensor& cauchy_(Tensor& self, double median, double sigma, std::optional<Generator> generator) {
98   return at::native::templates::cauchy_impl_<native::templates::cpu::CauchyKernel, TestCPUGenerator>(self, median, sigma, generator);
99 }
100 
101 // ================================================== LogNormal =======================================================
102 
log_normal_(Tensor & self,double mean,double std,std::optional<Generator> gen)103 Tensor& log_normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
104   return at::native::templates::log_normal_impl_<native::templates::cpu::LogNormalKernel, TestCPUGenerator>(self, mean, std, gen);
105 }
106 
107 // ================================================== Geometric =======================================================
108 
geometric_(Tensor & self,double p,std::optional<Generator> gen)109 Tensor& geometric_(Tensor& self, double p, std::optional<Generator> gen) {
110   return at::native::templates::geometric_impl_<native::templates::cpu::GeometricKernel, TestCPUGenerator>(self, p, gen);
111 }
112 
113 // ================================================== Exponential =====================================================
114 
exponential_(Tensor & self,double lambda,std::optional<Generator> gen)115 Tensor& exponential_(Tensor& self, double lambda, std::optional<Generator> gen) {
116   return at::native::templates::exponential_impl_<native::templates::cpu::ExponentialKernel, TestCPUGenerator>(self, lambda, gen);
117 }
118 
119 // ================================================== Bernoulli =======================================================
120 
bernoulli_Tensor(Tensor & self,const Tensor & p_,std::optional<Generator> gen)121 Tensor& bernoulli_Tensor(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
122   return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p_, gen);
123 }
124 
bernoulli_float(Tensor & self,double p,std::optional<Generator> gen)125 Tensor& bernoulli_float(Tensor& self, double p, std::optional<Generator> gen) {
126   return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p, gen);
127 }
128 
// CustomRNGKeyId kernel for `bernoulli.out`: `self` holds the probabilities and
// `result` receives the samples (bernoulli_out_impl takes the output first).
Tensor& bernoulli_out(const Tensor& self, std::optional<Generator> gen, Tensor& result) {
  namespace tmpl = at::native::templates;
  return tmpl::bernoulli_out_impl<tmpl::cpu::BernoulliKernel, TestCPUGenerator>(result, self, gen);
}
132 
// Installs the TestCPUGenerator-backed wrappers as the CustomRNGKeyId kernels
// for the listed aten ops (overload name -> wrapper function).
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
  // random_ family
  m.impl("random_.from", random_from_to);
  m.impl("random_.to", random_to);
  m.impl("random_", random_);

  // normal family: in-place, out= variants, then functional variants
  m.impl("normal_", normal_);
  m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
  m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
  m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
  m.impl("normal.Tensor_float", normal_Tensor_float);
  m.impl("normal.float_Tensor", normal_float_Tensor);
  m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);

  // uniform
  m.impl("uniform_", uniform_);

  // cauchy
  m.impl("cauchy_", cauchy_);

  // log-normal
  m.impl("log_normal_", log_normal_);

  // geometric
  m.impl("geometric_", geometric_);

  // exponential
  m.impl("exponential_", exponential_);

  // bernoulli family
  m.impl("bernoulli.out", bernoulli_out);
  m.impl("bernoulli_.Tensor", bernoulli_Tensor);
  m.impl("bernoulli_.float", bernoulli_float);
}
160 
161 class RNGTest : public ::testing::Test {
162 };
163 
// Arbitrary fixed value handed to TestCPUGenerator, so every draw is reproducible.
// (Internal linkage already comes from the enclosing anonymous namespace.)
constexpr auto MAGIC_NUMBER = 424242424242424242ULL;
165 
166 // ==================================================== Random ========================================================
167 
TEST_F(RNGTest, RandomFromTo) {
  // Runs the shared random_from_to checks (presumably from ATen/test/rng_test.h)
  // with the fake generator, once per dtype.
  const at::Device cpu_device{"cpu"};

  // bool
  test_random_from_to<TestCPUGenerator, torch::kBool, bool>(cpu_device);

  // integral types
  test_random_from_to<TestCPUGenerator, torch::kUInt8, uint8_t>(cpu_device);
  test_random_from_to<TestCPUGenerator, torch::kInt8, int8_t>(cpu_device);
  test_random_from_to<TestCPUGenerator, torch::kInt16, int16_t>(cpu_device);
  test_random_from_to<TestCPUGenerator, torch::kInt32, int32_t>(cpu_device);
  test_random_from_to<TestCPUGenerator, torch::kInt64, int64_t>(cpu_device);

  // floating-point types
  test_random_from_to<TestCPUGenerator, torch::kFloat32, float>(cpu_device);
  test_random_from_to<TestCPUGenerator, torch::kFloat64, double>(cpu_device);
}
179 
TEST_F(RNGTest, Random) {
  // Runs the shared random_ checks (presumably from ATen/test/rng_test.h)
  // with the fake generator, once per dtype.
  const at::Device cpu_device{"cpu"};

  // bool
  test_random<TestCPUGenerator, torch::kBool, bool>(cpu_device);

  // integral types
  test_random<TestCPUGenerator, torch::kUInt8, uint8_t>(cpu_device);
  test_random<TestCPUGenerator, torch::kInt8, int8_t>(cpu_device);
  test_random<TestCPUGenerator, torch::kInt16, int16_t>(cpu_device);
  test_random<TestCPUGenerator, torch::kInt32, int32_t>(cpu_device);
  test_random<TestCPUGenerator, torch::kInt64, int64_t>(cpu_device);

  // floating-point types
  test_random<TestCPUGenerator, torch::kFloat32, float>(cpu_device);
  test_random<TestCPUGenerator, torch::kFloat64, double>(cpu_device);
}
191 
192 // This test proves that Tensor.random_() distribution is able to generate unsigned 64 bit max value(64 ones)
193 // https://github.com/pytorch/pytorch/issues/33299
TEST_F(RNGTest,Random64bits)194 TEST_F(RNGTest, Random64bits) {
195   auto gen = at::make_generator<TestCPUGenerator>(std::numeric_limits<uint64_t>::max());
196   auto actual = torch::empty({1}, torch::kInt64);
197   actual.random_(std::numeric_limits<int64_t>::min(), std::nullopt, gen);
198   ASSERT_EQ(static_cast<uint64_t>(actual[0].item<int64_t>()), std::numeric_limits<uint64_t>::max());
199 }
200 
201 // ==================================================== Normal ========================================================
202 
TEST_F(RNGTest, Normal) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place normal_ with a CustomRNGKeyId generator.
  auto result = torch::empty({10});
  result.normal_(kMean, kStd, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
216 
TEST_F(RNGTest, Normal_float_Tensor_out) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: normal_out with scalar mean and tensor std.
  auto result = torch::empty({10});
  const auto std_tensor = torch::full({10}, kStd);
  at::normal_out(result, kMean, std_tensor, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
230 
TEST_F(RNGTest, Normal_Tensor_float_out) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: normal_out with tensor mean and scalar std.
  auto result = torch::empty({10});
  const auto mean_tensor = torch::full({10}, kMean);
  at::normal_out(result, mean_tensor, kStd, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
244 
TEST_F(RNGTest, Normal_Tensor_Tensor_out) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: normal_out with tensor mean and tensor std.
  auto result = torch::empty({10});
  const auto mean_tensor = torch::full({10}, kMean);
  const auto std_tensor = torch::full({10}, kStd);
  at::normal_out(result, mean_tensor, std_tensor, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
258 
TEST_F(RNGTest, Normal_float_Tensor) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: functional normal with scalar mean and tensor std.
  const auto std_tensor = torch::full({10}, kStd);
  auto result = at::normal(kMean, std_tensor, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
271 
TEST_F(RNGTest, Normal_Tensor_float) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: functional normal with tensor mean and scalar std.
  const auto mean_tensor = torch::full({10}, kMean);
  auto result = at::normal(mean_tensor, kStd, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
284 
TEST_F(RNGTest, Normal_Tensor_Tensor) {
  constexpr double kMean = 123.45;
  constexpr double kStd = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: functional normal with tensor mean and tensor std.
  const auto mean_tensor = torch::full({10}, kMean);
  const auto std_tensor = torch::full({10}, kStd);
  auto result = at::normal(mean_tensor, std_tensor, generator);

  // Reference: the stock CPU normal kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::normal_kernel(reference, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
297 
298 // ==================================================== Uniform =======================================================
299 
TEST_F(RNGTest, Uniform) {
  constexpr double kFrom = -24.24;
  constexpr double kTo = 42.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place uniform_.
  auto result = torch::empty({3, 3});
  result.uniform_(kFrom, kTo, generator);

  // Reference: the stock CPU uniform kernel run over a nullary iterator on `reference`.
  auto reference = torch::empty_like(result);
  auto reference_iter = TensorIterator::nullary_op(reference);
  native::templates::cpu::uniform_kernel(reference_iter, kFrom, kTo, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
314 
315 // ==================================================== Cauchy ========================================================
316 
TEST_F(RNGTest, Cauchy) {
  constexpr double kMedian = 123.45;
  constexpr double kSigma = 67.89;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place cauchy_.
  auto result = torch::empty({3, 3});
  result.cauchy_(kMedian, kSigma, generator);

  // Reference: the stock CPU cauchy kernel run over a nullary iterator on `reference`.
  auto reference = torch::empty_like(result);
  auto reference_iter = TensorIterator::nullary_op(reference);
  native::templates::cpu::cauchy_kernel(reference_iter, kMedian, kSigma, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
331 
332 // ================================================== LogNormal =======================================================
333 
TEST_F(RNGTest, LogNormal) {
  constexpr double kMean = 12.345;
  constexpr double kStd = 6.789;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place log_normal_.
  auto result = torch::empty({10});
  result.log_normal_(kMean, kStd, generator);

  // Reference: the stock CPU log-normal kernel run over a nullary iterator on `reference`.
  auto reference = torch::empty_like(result);
  auto reference_iter = TensorIterator::nullary_op(reference);
  native::templates::cpu::log_normal_kernel(reference_iter, kMean, kStd, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
348 
349 // ================================================== Geometric =======================================================
350 
TEST_F(RNGTest, Geometric) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place geometric_.
  auto result = torch::empty({3, 3});
  result.geometric_(kProb, generator);

  // Reference: the stock CPU geometric kernel run over a nullary iterator on `reference`.
  auto reference = torch::empty_like(result);
  auto reference_iter = TensorIterator::nullary_op(reference);
  native::templates::cpu::geometric_kernel(reference_iter, kProb, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
364 
365 // ================================================== Exponential =====================================================
366 
TEST_F(RNGTest, Exponential) {
  // Rate parameter. Spelled as a double literal so its type matches the
  // `double lambda` parameter of the exponential_ wrapper, rather than relying
  // on an implicit int -> double conversion at each call. Sampled values are
  // unchanged.
  const auto lambda = 42.0;
  auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place exponential_ with the CustomRNGKeyId generator.
  auto actual = torch::empty({3, 3});
  actual.exponential_(lambda, gen);

  // Reference: the stock CPU exponential kernel run over a nullary iterator on `expected`.
  auto expected = torch::empty_like(actual);
  auto iter = TensorIterator::nullary_op(expected);
  native::templates::cpu::exponential_kernel(iter, lambda, check_generator<TestCPUGenerator>(gen));

  ASSERT_TRUE(torch::allclose(actual, expected));
}
380 
381 // ==================================================== Bernoulli =====================================================
382 
TEST_F(RNGTest, Bernoulli_Tensor) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
  // Fresh 3x3 probability tensor per call, so the op and the reference each get their own input.
  const auto make_probs = [&] { return torch::full({3, 3}, kProb); };

  // Op under test: in-place bernoulli_ with a probability tensor.
  auto result = torch::empty({3, 3});
  result.bernoulli_(make_probs(), generator);

  // Reference: the stock CPU bernoulli kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, make_probs(), check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
395 
TEST_F(RNGTest, Bernoulli_scalar) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: in-place bernoulli_ with one scalar probability.
  auto result = torch::empty({3, 3});
  result.bernoulli_(kProb, generator);

  // Reference: the stock CPU bernoulli kernel (scalar-probability form).
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, kProb, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
408 
TEST_F(RNGTest, Bernoulli) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
  // Fresh 3x3 probability tensor per call.
  const auto make_probs = [&] { return torch::full({3, 3}, kProb); };

  // Op under test: free-function at::bernoulli on a probability tensor.
  auto result = at::bernoulli(make_probs(), generator);

  // Reference: the stock CPU bernoulli kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, make_probs(), check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
420 
TEST_F(RNGTest, Bernoulli_2) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
  // Fresh 3x3 probability tensor per call.
  const auto make_probs = [&] { return torch::full({3, 3}, kProb); };

  // Op under test: method-form Tensor::bernoulli on a probability tensor.
  auto result = make_probs().bernoulli(generator);

  // Reference: the stock CPU bernoulli kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, make_probs(), check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
432 
TEST_F(RNGTest, Bernoulli_p) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: free-function at::bernoulli with a scalar probability; the
  // input tensor only supplies the shape/options of the result.
  const auto shape_source = torch::empty({3, 3});
  auto result = at::bernoulli(shape_source, kProb, generator);

  // Reference: the stock CPU bernoulli kernel (scalar-probability form).
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, kProb, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
444 
TEST_F(RNGTest, Bernoulli_p_2) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);

  // Op under test: method-form Tensor::bernoulli with a scalar probability.
  const auto shape_source = torch::empty({3, 3});
  auto result = shape_source.bernoulli(kProb, generator);

  // Reference: the stock CPU bernoulli kernel (scalar-probability form).
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, kProb, check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
456 
TEST_F(RNGTest, Bernoulli_out) {
  constexpr double kProb = 0.42;
  auto generator = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
  // Fresh 3x3 probability tensor per call.
  const auto make_probs = [&] { return torch::full({3, 3}, kProb); };

  // Op under test: bernoulli_out writing into a preallocated tensor.
  auto result = torch::empty({3, 3});
  at::bernoulli_out(result, make_probs(), generator);

  // Reference: the stock CPU bernoulli kernel driven by the same generator.
  auto reference = torch::empty_like(result);
  native::templates::cpu::bernoulli_kernel(reference, make_probs(), check_generator<TestCPUGenerator>(generator));

  ASSERT_TRUE(torch::allclose(result, reference));
}
469 }
470 #endif // ATEN_CPU_STATIC_DISPATCH
471