1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 
5 #include <ATen/native/SobolEngineOpsUtils.h>
6 #include <c10/util/irange.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/_sobol_engine_draw_native.h>
13 #include <ATen/ops/_sobol_engine_ff_native.h>
14 #include <ATen/ops/_sobol_engine_initialize_state_native.h>
15 #include <ATen/ops/_sobol_engine_scramble_native.h>
16 #include <ATen/ops/arange_native.h>
17 #include <ATen/ops/empty.h>
18 #endif
19 
20 namespace at::native {
21 
22 using namespace sobol_utils;
23 
24 /// This is the core function to draw samples from a `SobolEngine` given
25 /// its state variables (`sobolstate` and `quasi`). `dimension` can be
26 /// inferred from `sobolstate`, but choosing to pass it explicitly to avoid
27 /// an extra operation to obtain the size of the first dimension of
28 /// `sobolstate`.
_sobol_engine_draw(const Tensor & quasi,int64_t n,const Tensor & sobolstate,int64_t dimension,int64_t num_generated,std::optional<ScalarType> dtype)29 std::tuple<Tensor, Tensor> _sobol_engine_draw(const Tensor& quasi, int64_t n, const Tensor& sobolstate,
30                                               int64_t dimension, int64_t num_generated, std::optional<ScalarType> dtype) {
31   TORCH_CHECK(sobolstate.dtype() == at::kLong,
32            "sobolstate needs to be of type ", at::kLong);
33   TORCH_CHECK(quasi.dtype() == at::kLong,
34            "quasi needs to be of type ", at::kLong);
35 
36   Tensor wquasi = quasi.clone(at::MemoryFormat::Contiguous);
37   auto result_dtype = dtype.has_value() ? dtype.value() : at::kFloat;
38   Tensor result = at::empty({n, dimension}, sobolstate.options().dtype(result_dtype));
39 
40   AT_DISPATCH_FLOATING_TYPES(result_dtype, "_sobol_engine_draw", [&]() -> void {
41     // We deal with `data` and `strides` due to performance issues.
42     int64_t l;
43     int64_t* wquasi_data = wquasi.data_ptr<int64_t>();
44     int64_t* sobolstate_data = sobolstate.data_ptr<int64_t>();
45     scalar_t* result_data = result.data_ptr<scalar_t>();
46 
47     int64_t wquasi_stride = wquasi.stride(0);
48     int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1);
49     int64_t result_row_stride = result.stride(0), result_col_stride = result.stride(1);
50 
51     for (int64_t i = 0; i < n; i++, num_generated++) {
52       l = rightmost_zero(num_generated);
53       for (const auto j : c10::irange(dimension)) {
54         wquasi_data[j * wquasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride];
55         result_data[i * result_row_stride + j * result_col_stride] = wquasi_data[j * wquasi_stride];
56       }
57     }
58   });
59 
60   result.mul_(RECIPD);
61   return std::tuple<Tensor, Tensor>(result, wquasi);
62 }
63 
64 /// This is the core function to fast-forward a `SobolEngine` given
65 /// its state variables (`sobolstate` and `quasi`). `dimension` can be
66 /// inferred from `sobolstate`, but is passed as an argument for the same reasons
67 /// specified above.
_sobol_engine_ff_(Tensor & quasi,int64_t n,const Tensor & sobolstate,int64_t dimension,int64_t num_generated)68 Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate,
69                         int64_t dimension, int64_t num_generated) {
70   TORCH_CHECK(sobolstate.dtype() == at::kLong,
71            "sobolstate needs to be of type ", at::kLong);
72   TORCH_CHECK(quasi.dtype() == at::kLong,
73            "quasi needs to be of type ", at::kLong);
74 
75   // We deal with `data` and `strides` due to performance issues.
76   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
77   int64_t l;
78   int64_t* quasi_data = quasi.data_ptr<int64_t>();
79   int64_t* sobolstate_data = sobolstate.data_ptr<int64_t>();
80 
81   int64_t quasi_stride = quasi.stride(0);
82   int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1);
83 
84   for (int64_t i = 0; i < n; i++, num_generated++) {
85     l = rightmost_zero(num_generated);
86     for (const auto j : c10::irange(dimension)) {
87       quasi_data[j * quasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride];
88     }
89   }
90   return quasi;
91 }
92 
93 /// This is an implicit function used for randomizing the state variables of the.
94 /// `SobolEngine`. Arguments are a randomized `sobolstate` state variables
95 /// and a list of random lower triangular matrices consisting of 0s and 1s. `dimension` is
96 /// passed explicitly again.
_sobol_engine_scramble_(Tensor & sobolstate,const Tensor & ltm,int64_t dimension)97 Tensor& _sobol_engine_scramble_(Tensor& sobolstate, const Tensor& ltm, int64_t dimension) {
98   TORCH_CHECK(sobolstate.dtype() == at::kLong,
99            "sobolstate needs to be of type ", at::kLong);
100 
101   /// Require a tensor accessor for `sobolstate`
102   auto ss_a = sobolstate.accessor<int64_t, 2>();
103 
104   /// For every tensor in the list of tensors, the diagonals are made 1
105   /// Require a dot product of every row with a specific vector of each of the matrices in `ltm`.
106   /// Instead, we perform an element-wise product of all the matrices and sum over the last dimension.
107   /// The required product of the m^{th} row in the d^{th} square matrix in `ltm` can be accessed
108   /// using ltm_d_a[d][m] m and d are zero-indexed
109   Tensor diag_true = ltm.clone(at::MemoryFormat::Contiguous);
110   diag_true.diagonal(0, -2, -1).fill_(1);
111   Tensor ltm_dots = cdot_pow2(diag_true);
112   auto ltm_d_a = ltm_dots.accessor<int64_t, 2>();
113 
114   /// Main scrambling loop
115   for (const auto d : c10::irange(dimension)) {
116     for (const auto j : c10::irange(MAXBIT)) {
117       int64_t vdj = ss_a[d][j], l = 1, t2 = 0;
118       for (int64_t p = MAXBIT - 1; p >= 0; --p) {
119         int64_t lsmdp = ltm_d_a[d][p];
120         int64_t t1 = 0;
121         for (const auto k : c10::irange(MAXBIT)) {
122           t1 += (bitsubseq(lsmdp, k, 1) * bitsubseq(vdj, k, 1));
123         }
124         t1 = t1 % 2;
125         t2 = t2 + t1 * l;
126         l = l << 1;
127       }
128       ss_a[d][j] = t2;
129     }
130   }
131   return sobolstate;
132 }
133 
134 /// This is a core function to initialize the main state variable of a `SobolEngine`.
135 /// `dimension` is passed explicitly as well (see why above)
_sobol_engine_initialize_state_(Tensor & sobolstate,int64_t dimension)136 Tensor& _sobol_engine_initialize_state_(Tensor& sobolstate, int64_t dimension) {
137   TORCH_CHECK(sobolstate.dtype() == at::kLong,
138            "sobolstate needs to be of type ", at::kLong);
139 
140   /// Use a tensor accessor for `sobolstate`
141   auto ss_a = sobolstate.accessor<int64_t, 2>();
142 
143   /// First row of `sobolstate` is all 1s
144   for (const auto m : c10::irange(MAXBIT)) {
145     ss_a[0][m] = 1;
146   }
147 
148   /// Remaining rows of sobolstate (row 2 through dim, indexed by [1:dim])
149   for (const auto d : c10::irange(1, dimension)) {
150     int64_t p = poly[d];
151     int64_t m = bit_length(p) - 1;
152 
153     // First m elements of row d comes from initsobolstate
154     for (const auto i : c10::irange(m)) {
155       ss_a[d][i] = initsobolstate[d][i];
156     }
157 
158     // Fill in remaining elements of v as in Section 2 (top of pg. 90) of:
159     // P. Bratley and B. L. Fox. Algorithm 659: Implementing sobol's
160     // quasirandom sequence generator. ACM Trans.
161     // Math. Softw., 14(1):88-100, Mar. 1988.
162     for (const auto j : c10::irange(m, MAXBIT)) {
163       int64_t newv = ss_a[d][j - m];
164       int64_t pow2 = 1;
165       for (const auto k : c10::irange(m)) {
166         pow2 <<= 1;
167         if ((p >> (m - 1 - k)) & 1) {
168           newv = newv ^ (pow2 * ss_a[d][j - k - 1]);
169         }
170       }
171       ss_a[d][j] = newv;
172     }
173   }
174 
175   /// Multiply each column of sobolstate by power of 2:
176   /// sobolstate * [2^(maxbit-1), 2^(maxbit-2),..., 2, 1]
177   Tensor pow2s = at::pow(
178       2,
179       at::native::arange(
180           (MAXBIT - 1),
181           -1,
182           -1,
183           optTypeMetaToScalarType(sobolstate.options().dtype_opt()),
184           sobolstate.options().layout_opt(),
185           sobolstate.options().device_opt(),
186           sobolstate.options().pinned_memory_opt()));
187   sobolstate.mul_(pow2s);
188   return sobolstate;
189 }
190 
191 } // namespace at::native
192