• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Based on "Notes on generating Sobol sequences. August 2008" by Joe and Kuo.
17 // [1] https://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-notes.pdf
18 #include <algorithm>
19 #include <cmath>
20 #include <cstdint>
21 #include <limits>
22 
23 #include "third_party/eigen3/Eigen/Core"
24 #include "sobol_data.h"
25 #include "tensorflow/core/framework/device_base.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/lib/core/threadpool.h"
28 #include "tensorflow/core/platform/platform_strings.h"
29 
30 namespace tensorflow {
31 
32 // Embed the platform strings in this binary.
33 TF_PLATFORM_STRINGS()
34 
35 typedef Eigen::ThreadPoolDevice CPUDevice;
36 
37 namespace {
38 
// Lower bound for the block size that SobolSampleOp::Compute hands to
// TransformRangeConcurrently, i.e. the smallest number of consecutive points
// a parallel work item is asked to generate.
constexpr int kMinBlockSize{512};
41 
// Returns number of digits in binary representation of n, i.e. the position
// of its most significant set bit plus one.
// Example: n=13. Binary representation is 1101. NumBinaryDigits(13) -> 4.
//
// Counts bits with exact integer arithmetic rather than std::log2: the
// floating-point route converts -inf/NaN to int for n <= 0 (undefined
// behavior) and depends on libm rounding near powers of two. Non-positive
// inputs return 0.
int NumBinaryDigits(int n) {
  if (n <= 0) return 0;
  int num_digits = 0;
  for (unsigned int bits = static_cast<unsigned int>(n); bits != 0u;
       bits >>= 1) {
    ++num_digits;
  }
  return num_digits;
}
47 
// Returns position of rightmost zero digit in binary representation of n.
// Example: n=13. Binary representation is 1101. RightmostZeroBit(13) -> 1.
//
// The scan runs on the unsigned bit pattern of n. Right-shifting a negative
// signed int replicates the sign bit, so a signed scan of n = -1 would never
// terminate. With an unsigned copy the loop ends after at most the bit width
// of unsigned int (which is the result for an all-ones pattern).
int RightmostZeroBit(int n) {
  unsigned int bits = static_cast<unsigned int>(n);
  int k = 0;
  while (bits & 1u) {
    bits >>= 1;
    ++k;
  }
  return k;
}
58 
59 // Returns an integer representation of point `i` in the Sobol sequence of
60 // dimension `dim` using the given direction numbers.
GetFirstPoint(int i,int dim,const Eigen::MatrixXi & direction_numbers)61 Eigen::VectorXi GetFirstPoint(int i, int dim,
62                               const Eigen::MatrixXi& direction_numbers) {
63   // Index variables used in this function, consistent with notation in [1].
64   // i - point in the Sobol sequence
65   // j - dimension
66   // k - binary digit
67   Eigen::VectorXi integer_sequence = Eigen::VectorXi::Zero(dim);
68   // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
69   int gray_code = i ^ (i >> 1);
70   int num_digits = NumBinaryDigits(i);
71   for (int j = 0; j < dim; ++j) {
72     for (int k = 0; k < num_digits; ++k) {
73       if ((gray_code >> k) & 1) integer_sequence(j) ^= direction_numbers(j, k);
74     }
75   }
76   return integer_sequence;
77 }
78 
79 // Calculates `num_results` Sobol points of dimension `dim` starting at the
80 // point `start_point + skip` and writes them into `output` starting at point
81 // `start_point`.
82 template <typename T>
CalculateSobolSample(int32_t dim,int32_t num_results,int32_t skip,int32_t start_point,typename TTypes<T>::Flat output)83 void CalculateSobolSample(int32_t dim, int32_t num_results, int32_t skip,
84                           int32_t start_point,
85                           typename TTypes<T>::Flat output) {
86   // Index variables used in this function, consistent with notation in [1].
87   // i - point in the Sobol sequence
88   // j - dimension
89   // k - binary digit
90   const int num_digits =
91       NumBinaryDigits(skip + start_point + num_results + 1);
92   Eigen::MatrixXi direction_numbers(dim, num_digits);
93 
94   // Shift things so we can use integers everywhere. Before we write to output,
95   // divide by constant to convert back to floats.
96   const T normalizing_constant = 1./(1 << num_digits);
97   for (int j = 0; j < dim; ++j) {
98     for (int k = 0; k < num_digits; ++k) {
99       direction_numbers(j, k) = sobol_data::kDirectionNumbers[j][k]
100                                 << (num_digits - k - 1);
101     }
102   }
103 
104   // If needed, skip ahead to the appropriate point in the sequence. Otherwise
105   // we start with the first column of direction numbers.
106   Eigen::VectorXi integer_sequence =
107       (skip + start_point > 0)
108           ? GetFirstPoint(skip + start_point + 1, dim, direction_numbers)
109           : direction_numbers.col(0);
110 
111   for (int j = 0; j < dim; ++j) {
112     output(start_point * dim + j) = integer_sequence(j) * normalizing_constant;
113   }
114   // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
115   for (int i = start_point + 1; i < num_results + start_point; ++i) {
116     // The Gray code for the current point differs from the preceding one by
117     // just a single bit -- the rightmost bit.
118     int k = RightmostZeroBit(i + skip);
119     // Update the current point from the preceding one with a single XOR
120     // operation per dimension.
121     for (int j = 0; j < dim; ++j) {
122       integer_sequence(j) ^= direction_numbers(j, k);
123       output(i * dim + j) = integer_sequence(j) * normalizing_constant;
124     }
125   }
126 }
127 
128 }  // namespace
129 
130 template <typename Device, typename T>
131 class SobolSampleOp : public OpKernel {
132  public:
SobolSampleOp(OpKernelConstruction * context)133   explicit SobolSampleOp(OpKernelConstruction* context)
134       : OpKernel(context) {}
135 
Compute(OpKernelContext * context)136   void Compute(OpKernelContext* context) override {
137     int32_t dim = context->input(0).scalar<int32_t>()();
138     int32_t num_results = context->input(1).scalar<int32_t>()();
139     int32_t skip = context->input(2).scalar<int32_t>()();
140 
141     OP_REQUIRES(context, dim >= 1,
142                 errors::InvalidArgument("dim must be at least one"));
143     OP_REQUIRES(context, dim <= sobol_data::kMaxSobolDim,
144                 errors::InvalidArgument("dim must be at most ",
145                                         sobol_data::kMaxSobolDim));
146     OP_REQUIRES(context, num_results >= 1,
147                 errors::InvalidArgument("num_results must be at least one"));
148     OP_REQUIRES(context, skip >= 0,
149                 errors::InvalidArgument("skip must be non-negative"));
150     OP_REQUIRES(context,
151                 num_results < std::numeric_limits<int32_t>::max() - skip,
152                 errors::InvalidArgument("num_results+skip must be less than ",
153                                         std::numeric_limits<int32_t>::max()));
154 
155     Tensor* output = nullptr;
156     OP_REQUIRES_OK(context,
157                    context->allocate_output(
158                        0, TensorShape({num_results, dim}), &output));
159     auto output_flat = output->flat<T>();
160     const DeviceBase::CpuWorkerThreads& worker_threads =
161         *(context->device()->tensorflow_cpu_worker_threads());
162     int num_threads = worker_threads.num_threads;
163     int block_size = std::max(
164         kMinBlockSize, static_cast<int>(std::ceil(
165                            static_cast<float>(num_results) / num_threads)));
166     worker_threads.workers->TransformRangeConcurrently(
167         block_size, num_results /* total */,
168         [&dim, &skip, &output_flat](const int start, const int end) {
169           CalculateSobolSample<T>(dim, end - start /* num_results */, skip,
170                                   start, output_flat);
171         });
172   }
173 };
174 
175 REGISTER_KERNEL_BUILDER(
176     Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<double>("dtype"),
177     SobolSampleOp<CPUDevice, double>);
178 REGISTER_KERNEL_BUILDER(
179     Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<float>("dtype"),
180     SobolSampleOp<CPUDevice, float>);
181 
182 }  // namespace tensorflow
183