/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementations of Random ops
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class RandomUniformOp : public XlaOpKernel {
 public:
  explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const DataType dtype = output_type(0);
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
                                        XlaHelpers::One(b, dtype), xla_shape);

    ctx->SetOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
};

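// The "shape" operand must be a compile-time constant: XLA compiles programs
// for fixed shapes, so the output shape has to be known when this kernel is
// compiled.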
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstantInput("shape"),
                RandomUniformOp);

class RandomShuffleOp : public XlaOpKernel {
 public:
  explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    TensorShape input_shape = ctx->InputShape(0);
    const int64 n = input_shape.dim_size(0);
    int64 num_elements = 1;
    for (tensorflow::TensorShapeDim dimension : input_shape) {
      num_elements *= dimension.size;
    }

    if (num_elements <= 1 || n <= 1) {
      // No shuffling is required, so copy the input directly to the output.
      ctx->SetOutput(0, input);
      return;
    }

    if (input_shape.dims() == 1) {
      // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
      // algorithm. Fisher-Yates is simple to implement and correct, but not
      // easily parallelizable. For a sufficiently parallel architecture, it is
      // faster to sort many times than to Fisher-Yates shuffle once.

      // Shuffle values by assigning each value a random key and sorting the
      // keys. Keys can collide, causing detectable patterns in the shuffled
      // output: collisions translate into more ascending sub-sequences in the
      // shuffled output than would be expected by chance. To avoid collisions,
      // the number of possible key values must be sufficiently large.

      // How are more than 2^32 keys created? In each loop iteration, the
      // algorithm sorts by random keys. Conceptually, the earlier iterations
      // are sorting on the lower-order bits of larger keys that are never
      // actually assembled.

      // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
      // the number of possible keys and n is the number of values. (This is n
      // minus the expected number of distinct keys, d(1 - (1 - 1/d)^n).) If
      // d = n^2, then the limit as n goes to infinity is 1/2. If d = n^3, then
      // the limit as n goes to infinity is zero.

      // This implementation ensures that the key space is greater than or
      // equal to the cube of the number of values. The risk of collisions can
      // be further reduced by increasing Exponent at the expense of
      // performance.

      // For Exponent = 2, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/2)) = 65535, where the expectation
      // is about 1/2.

      // For Exponent = 3, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/3)) = 1625, where the expectation
      // is about 1/3255.

      // For Exponent = 4, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/4)) = 255, where the expectation
      // is about 1/132622.
      constexpr int Exponent = 3;
      const int rounds = static_cast<int>(
          std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));
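      // For example, with num_elements = 1e6 this gives rounds =
      // ceil(3 * ln(1e6) / ln(2^32 - 1)) = ceil(1.87) = 2, i.e. two sorting
      // passes, which together provide an effective key space of about 64
      // bits, comfortably larger than n^3.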

      const xla::Shape key_shape =
          xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
      xla::XlaOp zero = xla::ConstantR0(builder, 0U);

      // Unfortunately, xla::RngUniform gives values in the half-open interval
      // rather than the closed interval, so instead of 2^32 possible keys
      // there are only 2^32 - 1 (kuint32max).
      xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);

      xla::XlaOp curr = input;
      for (int i = 0; i < rounds; ++i) {
        // Draw a fresh random key for every element and co-sort the values by
        // those keys; element 1 of the result tuple is the reordered values.
        xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
        xla::XlaOp sorted = xla::Sort(keys, {curr});
        curr = xla::GetTupleElement(sorted, 1);
      }

      ctx->SetOutput(0, curr);
      return;
    }

    // The Fisher-Yates algorithm.
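    //
    // A rough sketch of what the loop below computes, written as ordinary C++
    // for illustration only (the real computation is staged as XLA ops over
    // the `swaps` and `indices` operands):
    //
    //   std::vector<int32> indices(n);
    //   std::iota(indices.begin(), indices.end(), 0);
    //   for (int32 i = 0; i < n; ++i) {
    //     std::swap(indices[i], indices[swaps[i]]);  // swaps[i] is in [0, n)
    //   }
    //
    // The permuted indices are then used to gather rows of the input.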

    // Generate the random swaps for the indices.
    auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
    auto swaps =
        xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
                        xla::ConstantR0<int32>(builder, n), swaps_shape);

    // Generate range(n) as the initial value for the indices to be swapped.
    xla::XlaOp indices = xla::Iota(builder, xla::S32, n);

    // Swap the indices at i and swaps[i].
    auto swap_body_fn = [&](xla::XlaOp i,
                            absl::Span<const xla::XlaOp> loop_vars,
                            xla::XlaBuilder* builder)
        -> xla::StatusOr<std::vector<xla::XlaOp>> {
      auto swaps = loop_vars[0];
      auto indices = loop_vars[1];
      // TODO(b/118437727): The absl::Span nonsense is only necessary because
      // the deprecated overload creates ambiguity for the single-element span
      // case. Remove it once the deprecated overload is gone.
      // temp = indices[i]
      auto temp =
          xla::DynamicSlice(indices, absl::Span<const xla::XlaOp>({i}), {1});
      // swap_index = swaps[i]
      auto swap_index = xla::Reshape(
          xla::DynamicSlice(swaps, absl::Span<const xla::XlaOp>({i}), {1}), {});
      // swap_value = indices[swaps[i]]
      auto swap_value = xla::DynamicSlice(
          indices, absl::Span<const xla::XlaOp>({swap_index}), {1});
      // indices[i] = indices[swaps[i]]
      indices = xla::DynamicUpdateSlice(indices, swap_value,
                                        absl::Span<const xla::XlaOp>({i}));
      // indices[swaps[i]] = temp
      indices = xla::DynamicUpdateSlice(
          indices, temp, absl::Span<const xla::XlaOp>({swap_index}));
      return std::vector<xla::XlaOp>{swaps, indices};
    };
    // for i in range(n):
    auto swap_loop_result =
        xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
                          "indices_swap_loop", builder)
            .ValueOrDie();
    auto swapped_indices = swap_loop_result[1];

    // Gather the data using the swapped indices as the shuffled order.
    auto indices_tensor_shape = TensorShape({n});
    DataType type = ctx->expected_output_dtype(0);
    xla::XlaOp gather;
    OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
                                  indices_tensor_shape,
                                  /*axis=*/0, /*indices_are_nd=*/false, type,
                                  DT_INT32, builder, &gather));
    ctx->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
};

REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);

class RandomUniformIntOp : public XlaOpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx,
                   TensorShapeToXLAShape(input_type(1), shape, &xla_shape));

    const TensorShape minval_shape = ctx->InputShape(1);
    const TensorShape maxval_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval_shape.DebugString()));

    auto minval = ctx->Input(1);
    auto maxval = ctx->Input(2);
    ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
};

REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstantInput("shape"),
                RandomUniformIntOp);

class RandomStandardNormalOp : public XlaOpKernel {
 public:
  explicit RandomStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const DataType dtype = output_type(0);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();

    // Normal distribution with a mean of 0 and a standard deviation of 1:
    xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype),
                                       XlaHelpers::One(b, dtype), xla_shape);

    ctx->SetOutput(0, result);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
};

REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstantInput("shape"),
                RandomStandardNormalOp);

class TruncatedNormalOp : public XlaOpKernel {
 public:
  explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const DataType dtype = output_type(0);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));

    xla::XlaBuilder* b = ctx->builder();

    xla::XlaOp one = xla::One(b, xla_shape.element_type());
    xla::XlaOp min_positive =
        xla::MinPositiveNormalValue(b, xla_shape.element_type());
    // Draw uniform values in [min_positive, 1) and map them through
    // TruncatedNormal; starting at the smallest positive normal value keeps
    // the uniform sample strictly positive.
    auto uniform = xla::RngUniform(min_positive, one, xla_shape);
    ctx->SetOutput(0, TruncatedNormal(uniform));
  }
};

REGISTER_XLA_OP(Name("TruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", DT_FLOAT),
                TruncatedNormalOp);

}  // namespace
}  // namespace tensorflow