#include #include #include #include #include #include #include #include namespace { constexpr auto int64_min_val = std::numeric_limits::lowest(); constexpr auto int64_max_val = std::numeric_limits::max(); template , int> = 0> constexpr int64_t _min_val() { return int64_min_val; } template , int> = 0> constexpr int64_t _min_val() { return static_cast(std::numeric_limits::lowest()); } template , int> = 0> constexpr int64_t _min_from() { return -(static_cast(1) << std::numeric_limits::digits); } template , int> = 0> constexpr int64_t _min_from() { return _min_val(); } template , int> = 0> constexpr int64_t _max_val() { return int64_max_val; } template , int> = 0> constexpr int64_t _max_val() { return static_cast(std::numeric_limits::max()); } template , int> = 0> constexpr int64_t _max_to() { return static_cast(1) << std::numeric_limits::digits; } template , int> = 0> constexpr int64_t _max_to() { return _max_val(); } template void test_random_from_to(const at::Device& device) { constexpr int64_t max_val = _max_val(); constexpr int64_t max_to = _max_to(); constexpr auto uint64_max_val = std::numeric_limits::max(); std::vector froms; std::vector<::std::optional> tos; if constexpr (::std::is_same_v) { froms = { 0L }; tos = { 1L, static_cast<::std::optional>(::std::nullopt) }; } else if constexpr (::std::is_signed_v) { constexpr int64_t min_from = _min_from(); froms = { min_from, -42L, 0L, 42L }; tos = { ::std::optional(-42L), ::std::optional(0L), ::std::optional(42L), ::std::optional(max_to), static_cast<::std::optional>(::std::nullopt) }; } else { froms = { 0L, 42L }; tos = { ::std::optional(42L), ::std::optional(max_to), static_cast<::std::optional>(::std::nullopt) }; } const std::vector vals = { 0L, 42L, static_cast(max_val), static_cast(max_val) + 1, uint64_max_val }; bool full_64_bit_range_case_covered = false; bool from_to_case_covered = false; bool from_case_covered = false; for (const int64_t from : froms) { for (const ::std::optional & to : tos) { if (!to.has_value() || from < *to) { for (const uint64_t val : vals) { auto gen = at::make_generator(val); auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device)); actual.random_(from, to, gen); T exp; uint64_t range; if (!to.has_value() && from == int64_min_val) { exp = static_cast(val); full_64_bit_range_case_covered = true; } else { if (to.has_value()) { range = static_cast(*to) - static_cast(from); from_to_case_covered = true; } else { range = static_cast(max_to) - static_cast(from) + 1; from_case_covered = true; } if (range < (1ULL << 32)) { exp = static_cast(static_cast((static_cast(val) % range + from))); } else { exp = static_cast(static_cast((val % range + from))); } } ASSERT_TRUE(from <= exp); if (to.has_value()) { ASSERT_TRUE(static_cast(exp) < *to); } const auto expected = torch::full_like(actual, exp); if constexpr (::std::is_same_v) { ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt))); } else { ASSERT_TRUE(torch::allclose(actual, expected)); } } } } } if constexpr (::std::is_same_v) { ASSERT_TRUE(full_64_bit_range_case_covered); } else { (void)full_64_bit_range_case_covered; } ASSERT_TRUE(from_to_case_covered); ASSERT_TRUE(from_case_covered); } template void test_random(const at::Device& device) { const auto max_val = _max_val(); const auto uint64_max_val = std::numeric_limits::max(); const std::vector vals = { 0L, 42L, static_cast(max_val), static_cast(max_val) + 1, uint64_max_val }; for (const uint64_t val : vals) { auto gen = at::make_generator(val); auto actual = torch::empty({3, 3}, torch::TensorOptions().dtype(S).device(device)); actual.random_(gen); uint64_t range; if constexpr (::std::is_floating_point_v) { range = static_cast((1ULL << ::std::numeric_limits::digits) + 1); } else if constexpr (::std::is_same_v) { range = 2; } else { range = static_cast(::std::numeric_limits::max()) + 1; } T exp; if constexpr (::std::is_same_v || ::std::is_same_v) { exp = val % range; } else { exp = static_cast(val) % range; } ASSERT_TRUE(0 <= static_cast(exp)); ASSERT_TRUE(static_cast(exp) < range); const auto expected = torch::full_like(actual, exp); if constexpr (::std::is_same_v) { ASSERT_TRUE(torch::allclose(actual.toType(torch::kInt), expected.toType(torch::kInt))); } else { ASSERT_TRUE(torch::allclose(actual, expected)); } } } }