1 #include <gtest/gtest.h>
2 #include <ATen/test/rng_test.h>
3 #include <ATen/Generator.h>
4 #include <c10/core/GeneratorImpl.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/native/DistributionTemplates.h>
7 #include <ATen/native/cpu/DistributionTemplates.h>
8 #include <torch/library.h>
9 #include <optional>
10 #include <torch/all.h>
11 #include <stdexcept>
12
13 using namespace at;
14
15 #ifndef ATEN_CPU_STATIC_DISPATCH
16 namespace {
17
18 constexpr auto kCustomRNG = DispatchKey::CustomRNGKeyId;
19
20 struct TestCPUGenerator : public c10::GeneratorImpl {
TestCPUGenerator__anon6d9ca2b10111::TestCPUGenerator21 TestCPUGenerator(uint64_t value) : GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(kCustomRNG)}, value_(value) { }
22 ~TestCPUGenerator() override = default;
random__anon6d9ca2b10111::TestCPUGenerator23 uint32_t random() { return value_; }
random64__anon6d9ca2b10111::TestCPUGenerator24 uint64_t random64() { return value_; }
next_float_normal_sample__anon6d9ca2b10111::TestCPUGenerator25 std::optional<float> next_float_normal_sample() { return next_float_normal_sample_; }
next_double_normal_sample__anon6d9ca2b10111::TestCPUGenerator26 std::optional<double> next_double_normal_sample() { return next_double_normal_sample_; }
set_next_float_normal_sample__anon6d9ca2b10111::TestCPUGenerator27 void set_next_float_normal_sample(std::optional<float> randn) { next_float_normal_sample_ = randn; }
set_next_double_normal_sample__anon6d9ca2b10111::TestCPUGenerator28 void set_next_double_normal_sample(std::optional<double> randn) { next_double_normal_sample_ = randn; }
set_current_seed__anon6d9ca2b10111::TestCPUGenerator29 void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
set_offset__anon6d9ca2b10111::TestCPUGenerator30 void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
get_offset__anon6d9ca2b10111::TestCPUGenerator31 uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
current_seed__anon6d9ca2b10111::TestCPUGenerator32 uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
seed__anon6d9ca2b10111::TestCPUGenerator33 uint64_t seed() override { throw std::runtime_error("not implemented"); }
set_state__anon6d9ca2b10111::TestCPUGenerator34 void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
get_state__anon6d9ca2b10111::TestCPUGenerator35 c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
clone_impl__anon6d9ca2b10111::TestCPUGenerator36 TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
37
device_type__anon6d9ca2b10111::TestCPUGenerator38 static DeviceType device_type() { return DeviceType::CPU; }
39
40 uint64_t value_;
41 std::optional<float> next_float_normal_sample_;
42 std::optional<double> next_double_normal_sample_;
43 };
44
45 // ==================================================== Random ========================================================
46
random_(Tensor & self,std::optional<Generator> generator)47 Tensor& random_(Tensor& self, std::optional<Generator> generator) {
48 return at::native::templates::random_impl<native::templates::cpu::RandomKernel, TestCPUGenerator>(self, generator);
49 }
50
random_from_to(Tensor & self,int64_t from,std::optional<int64_t> to,std::optional<Generator> generator)51 Tensor& random_from_to(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> generator) {
52 return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
53 }
54
random_to(Tensor & self,int64_t to,std::optional<Generator> generator)55 Tensor& random_to(Tensor& self, int64_t to, std::optional<Generator> generator) {
56 return random_from_to(self, 0, to, generator);
57 }
58
59 // ==================================================== Normal ========================================================
60
normal_(Tensor & self,double mean,double std,std::optional<Generator> gen)61 Tensor& normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
62 return at::native::templates::normal_impl_<native::templates::cpu::NormalKernel, TestCPUGenerator>(self, mean, std, gen);
63 }
64
normal_Tensor_float_out(const Tensor & mean,double std,std::optional<Generator> gen,Tensor & output)65 Tensor& normal_Tensor_float_out(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
66 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
67 }
68
normal_float_Tensor_out(double mean,const Tensor & std,std::optional<Generator> gen,Tensor & output)69 Tensor& normal_float_Tensor_out(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
70 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
71 }
72
normal_Tensor_Tensor_out(const Tensor & mean,const Tensor & std,std::optional<Generator> gen,Tensor & output)73 Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
74 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
75 }
76
normal_Tensor_float(const Tensor & mean,double std,std::optional<Generator> gen)77 Tensor normal_Tensor_float(const Tensor& mean, double std, std::optional<Generator> gen) {
78 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
79 }
80
normal_float_Tensor(double mean,const Tensor & std,std::optional<Generator> gen)81 Tensor normal_float_Tensor(double mean, const Tensor& std, std::optional<Generator> gen) {
82 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
83 }
84
normal_Tensor_Tensor(const Tensor & mean,const Tensor & std,std::optional<Generator> gen)85 Tensor normal_Tensor_Tensor(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
86 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
87 }
88
89 // ==================================================== Uniform =======================================================
90
uniform_(Tensor & self,double from,double to,std::optional<Generator> generator)91 Tensor& uniform_(Tensor& self, double from, double to, std::optional<Generator> generator) {
92 return at::native::templates::uniform_impl_<native::templates::cpu::UniformKernel, TestCPUGenerator>(self, from, to, generator);
93 }
94
95 // ==================================================== Cauchy ========================================================
96
cauchy_(Tensor & self,double median,double sigma,std::optional<Generator> generator)97 Tensor& cauchy_(Tensor& self, double median, double sigma, std::optional<Generator> generator) {
98 return at::native::templates::cauchy_impl_<native::templates::cpu::CauchyKernel, TestCPUGenerator>(self, median, sigma, generator);
99 }
100
101 // ================================================== LogNormal =======================================================
102
log_normal_(Tensor & self,double mean,double std,std::optional<Generator> gen)103 Tensor& log_normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
104 return at::native::templates::log_normal_impl_<native::templates::cpu::LogNormalKernel, TestCPUGenerator>(self, mean, std, gen);
105 }
106
107 // ================================================== Geometric =======================================================
108
geometric_(Tensor & self,double p,std::optional<Generator> gen)109 Tensor& geometric_(Tensor& self, double p, std::optional<Generator> gen) {
110 return at::native::templates::geometric_impl_<native::templates::cpu::GeometricKernel, TestCPUGenerator>(self, p, gen);
111 }
112
113 // ================================================== Exponential =====================================================
114
exponential_(Tensor & self,double lambda,std::optional<Generator> gen)115 Tensor& exponential_(Tensor& self, double lambda, std::optional<Generator> gen) {
116 return at::native::templates::exponential_impl_<native::templates::cpu::ExponentialKernel, TestCPUGenerator>(self, lambda, gen);
117 }
118
119 // ================================================== Bernoulli =======================================================
120
bernoulli_Tensor(Tensor & self,const Tensor & p_,std::optional<Generator> gen)121 Tensor& bernoulli_Tensor(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
122 return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p_, gen);
123 }
124
bernoulli_float(Tensor & self,double p,std::optional<Generator> gen)125 Tensor& bernoulli_float(Tensor& self, double p, std::optional<Generator> gen) {
126 return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p, gen);
127 }
128
bernoulli_out(const Tensor & self,std::optional<Generator> gen,Tensor & result)129 Tensor& bernoulli_out(const Tensor& self, std::optional<Generator> gen, Tensor& result) {
130 return at::native::templates::bernoulli_out_impl<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(result, self, gen);
131 }
132
// Kernels for the `aten` operators under the CustomRNGKeyId dispatch key, the key TestCPUGenerator is
// constructed with. Each schema name is bound to the adapter of the same family defined in this file.
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
  // --- Random ---
  m.impl("random_.from", random_from_to);
  m.impl("random_.to", random_to);
  m.impl("random_", random_);

  // --- Normal (in-place, out= and functional variants) ---
  m.impl("normal_", normal_);
  m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
  m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
  m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
  m.impl("normal.Tensor_float", normal_Tensor_float);
  m.impl("normal.float_Tensor", normal_float_Tensor);
  m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);

  // --- Uniform ---
  m.impl("uniform_", uniform_);

  // --- Cauchy ---
  m.impl("cauchy_", cauchy_);

  // --- LogNormal ---
  m.impl("log_normal_", log_normal_);

  // --- Geometric ---
  m.impl("geometric_", geometric_);

  // --- Exponential ---
  m.impl("exponential_", exponential_);

  // --- Bernoulli ---
  m.impl("bernoulli.out", bernoulli_out);
  m.impl("bernoulli_.Tensor", bernoulli_Tensor);
  m.impl("bernoulli_.float", bernoulli_float);
}
160
161 class RNGTest : public ::testing::Test {
162 };
163
164 static constexpr auto MAGIC_NUMBER = 424242424242424242ULL;
165
166 // ==================================================== Random ========================================================
167
TEST_F(RNGTest,RandomFromTo)168 TEST_F(RNGTest, RandomFromTo) {
169 const at::Device device("cpu");
170 test_random_from_to<TestCPUGenerator, torch::kBool, bool>(device);
171 test_random_from_to<TestCPUGenerator, torch::kUInt8, uint8_t>(device);
172 test_random_from_to<TestCPUGenerator, torch::kInt8, int8_t>(device);
173 test_random_from_to<TestCPUGenerator, torch::kInt16, int16_t>(device);
174 test_random_from_to<TestCPUGenerator, torch::kInt32, int32_t>(device);
175 test_random_from_to<TestCPUGenerator, torch::kInt64, int64_t>(device);
176 test_random_from_to<TestCPUGenerator, torch::kFloat32, float>(device);
177 test_random_from_to<TestCPUGenerator, torch::kFloat64, double>(device);
178 }
179
TEST_F(RNGTest,Random)180 TEST_F(RNGTest, Random) {
181 const at::Device device("cpu");
182 test_random<TestCPUGenerator, torch::kBool, bool>(device);
183 test_random<TestCPUGenerator, torch::kUInt8, uint8_t>(device);
184 test_random<TestCPUGenerator, torch::kInt8, int8_t>(device);
185 test_random<TestCPUGenerator, torch::kInt16, int16_t>(device);
186 test_random<TestCPUGenerator, torch::kInt32, int32_t>(device);
187 test_random<TestCPUGenerator, torch::kInt64, int64_t>(device);
188 test_random<TestCPUGenerator, torch::kFloat32, float>(device);
189 test_random<TestCPUGenerator, torch::kFloat64, double>(device);
190 }
191
192 // This test proves that Tensor.random_() distribution is able to generate unsigned 64 bit max value(64 ones)
193 // https://github.com/pytorch/pytorch/issues/33299
TEST_F(RNGTest,Random64bits)194 TEST_F(RNGTest, Random64bits) {
195 auto gen = at::make_generator<TestCPUGenerator>(std::numeric_limits<uint64_t>::max());
196 auto actual = torch::empty({1}, torch::kInt64);
197 actual.random_(std::numeric_limits<int64_t>::min(), std::nullopt, gen);
198 ASSERT_EQ(static_cast<uint64_t>(actual[0].item<int64_t>()), std::numeric_limits<uint64_t>::max());
199 }
200
201 // ==================================================== Normal ========================================================
202
TEST_F(RNGTest,Normal)203 TEST_F(RNGTest, Normal) {
204 const auto mean = 123.45;
205 const auto std = 67.89;
206 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
207
208 auto actual = torch::empty({10});
209 actual.normal_(mean, std, gen);
210
211 auto expected = torch::empty_like(actual);
212 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
213
214 ASSERT_TRUE(torch::allclose(actual, expected));
215 }
216
TEST_F(RNGTest,Normal_float_Tensor_out)217 TEST_F(RNGTest, Normal_float_Tensor_out) {
218 const auto mean = 123.45;
219 const auto std = 67.89;
220 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
221
222 auto actual = torch::empty({10});
223 at::normal_out(actual, mean, torch::full({10}, std), gen);
224
225 auto expected = torch::empty_like(actual);
226 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
227
228 ASSERT_TRUE(torch::allclose(actual, expected));
229 }
230
TEST_F(RNGTest,Normal_Tensor_float_out)231 TEST_F(RNGTest, Normal_Tensor_float_out) {
232 const auto mean = 123.45;
233 const auto std = 67.89;
234 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
235
236 auto actual = torch::empty({10});
237 at::normal_out(actual, torch::full({10}, mean), std, gen);
238
239 auto expected = torch::empty_like(actual);
240 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
241
242 ASSERT_TRUE(torch::allclose(actual, expected));
243 }
244
TEST_F(RNGTest,Normal_Tensor_Tensor_out)245 TEST_F(RNGTest, Normal_Tensor_Tensor_out) {
246 const auto mean = 123.45;
247 const auto std = 67.89;
248 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
249
250 auto actual = torch::empty({10});
251 at::normal_out(actual, torch::full({10}, mean), torch::full({10}, std), gen);
252
253 auto expected = torch::empty_like(actual);
254 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
255
256 ASSERT_TRUE(torch::allclose(actual, expected));
257 }
258
TEST_F(RNGTest,Normal_float_Tensor)259 TEST_F(RNGTest, Normal_float_Tensor) {
260 const auto mean = 123.45;
261 const auto std = 67.89;
262 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
263
264 auto actual = at::normal(mean, torch::full({10}, std), gen);
265
266 auto expected = torch::empty_like(actual);
267 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
268
269 ASSERT_TRUE(torch::allclose(actual, expected));
270 }
271
TEST_F(RNGTest,Normal_Tensor_float)272 TEST_F(RNGTest, Normal_Tensor_float) {
273 const auto mean = 123.45;
274 const auto std = 67.89;
275 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
276
277 auto actual = at::normal(torch::full({10}, mean), std, gen);
278
279 auto expected = torch::empty_like(actual);
280 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
281
282 ASSERT_TRUE(torch::allclose(actual, expected));
283 }
284
TEST_F(RNGTest,Normal_Tensor_Tensor)285 TEST_F(RNGTest, Normal_Tensor_Tensor) {
286 const auto mean = 123.45;
287 const auto std = 67.89;
288 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
289
290 auto actual = at::normal(torch::full({10}, mean), torch::full({10}, std), gen);
291
292 auto expected = torch::empty_like(actual);
293 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
294
295 ASSERT_TRUE(torch::allclose(actual, expected));
296 }
297
298 // ==================================================== Uniform =======================================================
299
TEST_F(RNGTest,Uniform)300 TEST_F(RNGTest, Uniform) {
301 const auto from = -24.24;
302 const auto to = 42.42;
303 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
304
305 auto actual = torch::empty({3, 3});
306 actual.uniform_(from, to, gen);
307
308 auto expected = torch::empty_like(actual);
309 auto iter = TensorIterator::nullary_op(expected);
310 native::templates::cpu::uniform_kernel(iter, from, to, check_generator<TestCPUGenerator>(gen));
311
312 ASSERT_TRUE(torch::allclose(actual, expected));
313 }
314
315 // ==================================================== Cauchy ========================================================
316
TEST_F(RNGTest,Cauchy)317 TEST_F(RNGTest, Cauchy) {
318 const auto median = 123.45;
319 const auto sigma = 67.89;
320 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
321
322 auto actual = torch::empty({3, 3});
323 actual.cauchy_(median, sigma, gen);
324
325 auto expected = torch::empty_like(actual);
326 auto iter = TensorIterator::nullary_op(expected);
327 native::templates::cpu::cauchy_kernel(iter, median, sigma, check_generator<TestCPUGenerator>(gen));
328
329 ASSERT_TRUE(torch::allclose(actual, expected));
330 }
331
332 // ================================================== LogNormal =======================================================
333
TEST_F(RNGTest,LogNormal)334 TEST_F(RNGTest, LogNormal) {
335 const auto mean = 12.345;
336 const auto std = 6.789;
337 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
338
339 auto actual = torch::empty({10});
340 actual.log_normal_(mean, std, gen);
341
342 auto expected = torch::empty_like(actual);
343 auto iter = TensorIterator::nullary_op(expected);
344 native::templates::cpu::log_normal_kernel(iter, mean, std, check_generator<TestCPUGenerator>(gen));
345
346 ASSERT_TRUE(torch::allclose(actual, expected));
347 }
348
349 // ================================================== Geometric =======================================================
350
TEST_F(RNGTest,Geometric)351 TEST_F(RNGTest, Geometric) {
352 const auto p = 0.42;
353 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
354
355 auto actual = torch::empty({3, 3});
356 actual.geometric_(p, gen);
357
358 auto expected = torch::empty_like(actual);
359 auto iter = TensorIterator::nullary_op(expected);
360 native::templates::cpu::geometric_kernel(iter, p, check_generator<TestCPUGenerator>(gen));
361
362 ASSERT_TRUE(torch::allclose(actual, expected));
363 }
364
365 // ================================================== Exponential =====================================================
366
TEST_F(RNGTest,Exponential)367 TEST_F(RNGTest, Exponential) {
368 const auto lambda = 42;
369 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
370
371 auto actual = torch::empty({3, 3});
372 actual.exponential_(lambda, gen);
373
374 auto expected = torch::empty_like(actual);
375 auto iter = TensorIterator::nullary_op(expected);
376 native::templates::cpu::exponential_kernel(iter, lambda, check_generator<TestCPUGenerator>(gen));
377
378 ASSERT_TRUE(torch::allclose(actual, expected));
379 }
380
381 // ==================================================== Bernoulli =====================================================
382
TEST_F(RNGTest,Bernoulli_Tensor)383 TEST_F(RNGTest, Bernoulli_Tensor) {
384 const auto p = 0.42;
385 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
386
387 auto actual = torch::empty({3, 3});
388 actual.bernoulli_(torch::full({3,3}, p), gen);
389
390 auto expected = torch::empty_like(actual);
391 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
392
393 ASSERT_TRUE(torch::allclose(actual, expected));
394 }
395
TEST_F(RNGTest,Bernoulli_scalar)396 TEST_F(RNGTest, Bernoulli_scalar) {
397 const auto p = 0.42;
398 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
399
400 auto actual = torch::empty({3, 3});
401 actual.bernoulli_(p, gen);
402
403 auto expected = torch::empty_like(actual);
404 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
405
406 ASSERT_TRUE(torch::allclose(actual, expected));
407 }
408
TEST_F(RNGTest,Bernoulli)409 TEST_F(RNGTest, Bernoulli) {
410 const auto p = 0.42;
411 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
412
413 auto actual = at::bernoulli(torch::full({3,3}, p), gen);
414
415 auto expected = torch::empty_like(actual);
416 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
417
418 ASSERT_TRUE(torch::allclose(actual, expected));
419 }
420
TEST_F(RNGTest,Bernoulli_2)421 TEST_F(RNGTest, Bernoulli_2) {
422 const auto p = 0.42;
423 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
424
425 auto actual = torch::full({3,3}, p).bernoulli(gen);
426
427 auto expected = torch::empty_like(actual);
428 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
429
430 ASSERT_TRUE(torch::allclose(actual, expected));
431 }
432
TEST_F(RNGTest,Bernoulli_p)433 TEST_F(RNGTest, Bernoulli_p) {
434 const auto p = 0.42;
435 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
436
437 auto actual = at::bernoulli(torch::empty({3, 3}), p, gen);
438
439 auto expected = torch::empty_like(actual);
440 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
441
442 ASSERT_TRUE(torch::allclose(actual, expected));
443 }
444
TEST_F(RNGTest,Bernoulli_p_2)445 TEST_F(RNGTest, Bernoulli_p_2) {
446 const auto p = 0.42;
447 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
448
449 auto actual = torch::empty({3, 3}).bernoulli(p, gen);
450
451 auto expected = torch::empty_like(actual);
452 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
453
454 ASSERT_TRUE(torch::allclose(actual, expected));
455 }
456
TEST_F(RNGTest,Bernoulli_out)457 TEST_F(RNGTest, Bernoulli_out) {
458 const auto p = 0.42;
459 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
460
461 auto actual = torch::empty({3, 3});
462 at::bernoulli_out(actual, torch::full({3,3}, p), gen);
463
464 auto expected = torch::empty_like(actual);
465 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
466
467 ASSERT_TRUE(torch::allclose(actual, expected));
468 }
469 }
470 #endif // ATEN_CPU_STATIC_DISPATCH
471