/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateful_random_ops.h"

#include <cmath>
#include <utility>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {
xla::BitGeneratorTy BitGen(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      state =
          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
      xla::XlaOp result =
          xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state =
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
      return xla::RngOutput{data, new_state};
    };
  } else {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      state = xla::ConcatScalars(key.builder(), {key, state});
      xla::XlaOp result = xla::RngBitGenerator(
          xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state = xla::Reshape(
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
      return xla::RngOutput{data, new_state};
    };
  }
}

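// Samples a tensor of `shape` with values uniformly distributed in
// [minval, maxval). Supports floating-point and integral element types; any
// other type is reported as an error on the builder.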
xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,
                                  xla::XlaOp initial_state,
                                  const xla::Shape& shape, xla::XlaOp minval,
                                  xla::XlaOp maxval) {
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(
          key, initial_state, BitGen(alg), minval, maxval, shape);
    case xla::U32:
    case xla::S32:
    case xla::U64:
    case xla::S64:
      return xla::UniformIntDistribution(key, initial_state, BitGen(alg),
                                         minval, maxval, shape);
    default:
      return {key.builder()->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, U32, S32, U64 and S64 "
                  "are not implemented by StatefulRngUniform; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              initial_state};
  }
}

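// Samples a tensor of `shape` with integers uniformly distributed over the
// whole range of the element type. Signed types reuse the generated unsigned
// bits via a bitcast.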
xla::RngOutput StatefulRngUniformFullInt(Algorithm alg, xla::XlaOp key,
                                         xla::XlaOp initial_state,
                                         const xla::Shape& shape) {
  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGen(alg)(key, initial_state, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      output.value = xla::BitcastConvertType(output.value, type);
      return output;
    default:
      return {
          key.builder()->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatefulRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          initial_state};
  }
}

using SamplerReturnType = StatusOr<xla::RngOutput>;

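// Returns the minimum number of state elements required by `alg`.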
int64_t GetMinStateSize(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return PHILOX_MIN_STATE_SIZE;
  }
  return THREEFRY_MIN_STATE_SIZE;
}

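// Validates that the state variable is a 1-D tensor large enough to hold the
// counter and key for `alg`.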
Status CheckStateShape(Algorithm alg, const TensorShape& shape) {
  if (shape.dims() != 1) {
    return errors::InvalidArgument(
        "RNG state must have one and only one dimension, not ", shape.dims());
  }
  auto state_size = shape.dim_size(0);
  auto min_state_size = GetMinStateSize(alg);
  if (state_size < min_state_size) {
    return errors::InvalidArgument("The size of the state must be at least ",
                                   min_state_size, "; got ", state_size);
  }
  return OkStatus();
}

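// Splits the state variable into a (counter, key) pair, bitcast to U64. For
// ThreeFry the variable layout is [counter, key] with a scalar counter; for
// Philox it is [counter0, counter1, key] with a 2-element counter.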
std::pair<xla::XlaOp, xla::XlaOp> StateAndKeyFromVariable(Algorithm alg,
                                                          xla::XlaOp var) {
  if (alg == RNG_ALG_THREEFRY) {
    static constexpr int kStateSize = 1;
    auto state = xla::BitcastConvertType(
        xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
    auto key = xla::BitcastConvertType(
        xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
        xla::U64);
    return std::make_pair(state, key);
  } else {
    static constexpr int kStateSize = 2;
    auto state = xla::BitcastConvertType(
        xla::Slice(var, {0}, {kStateSize}, {1}), xla::U64);
    auto key = xla::Reshape(
        xla::BitcastConvertType(
            xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), xla::U64),
        {});
    return std::make_pair(state, key);
  }
}

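// Inverse of StateAndKeyFromVariable: packs the counter and key back into the
// flat variable layout for `alg`.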
xla::XlaOp StateAndKeyToVariable(Algorithm alg, xla::XlaOp state,
                                 xla::XlaOp key) {
  auto builder = state.builder();
  if (alg == RNG_ALG_THREEFRY) {
    return xla::ConcatScalars(builder, {state, key});
  } else {
    return xla::ConcatInDim(builder, {state, xla::Reshape(key, {1})}, 0);
  }
}

// A helper function containing the common part of several kernels below: it
// reads the algorithm and state variable, runs `sampler`, sets the kernel's
// output, and writes the updated state back to the variable.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(
    XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
    int shape_input_idx,
    std::function<SamplerReturnType(Algorithm, xla::XlaOp, xla::XlaOp,
                                    TensorShape)> const& sampler) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  Algorithm alg = Algorithm(alg_literal.Get<int64_t>({}));
  if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
    return errors::InvalidArgument("Unsupported algorithm id: ", alg);
  }

  xla::XlaOp var;
  TensorShape var_shape;
  TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
      state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
  TF_RETURN_IF_ERROR(CheckStateShape(alg, var_shape));
  TensorShape shape;
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
  xla::XlaOp state;
  xla::XlaOp key;
  std::tie(state, key) = StateAndKeyFromVariable(alg, var);
  auto status_or_value = sampler(alg, state, key, shape);
  if (!status_or_value.ok()) {
    return status_or_value.status();
  }
  xla::RngOutput value_state = std::move(status_or_value).value();
  state = value_state.state;
  ctx->SetOutput(0, value_state.value);
  var = StateAndKeyToVariable(alg, state, key);
  xla::PrimitiveType state_element_type;
  TF_RETURN_IF_ERROR(
      DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
  var = xla::BitcastConvertType(var, state_element_type);
  TF_RETURN_IF_ERROR(
      ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  return OkStatus();
}

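// Kernel for StatefulUniform: samples floating-point values uniformly from
// [0, 1), computing in F32 (or F64 for DT_DOUBLE) and converting to BF16 at
// the end when requested.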
class StatefulUniformOp : public XlaOpKernel {
 public:
  explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler = [builder, this](Algorithm alg, xla::XlaOp state,
                                   xla::XlaOp key,
                                   TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
      xla::RngOutput uniform_state = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
          xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
      xla::XlaOp uniform = uniform_state.value;
      state = uniform_state.state;
      uniform = MaybeConvertF32ToBF16(uniform, dtype_);
      return {{uniform, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulUniform")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulUniformOp);

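// Kernel for StatefulStandardNormalV2: samples from the standard normal
// distribution N(0, 1).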
class StatefulStandardNormalOp : public XlaOpKernel {
 public:
  explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler =
        // The lambda needs an explicit return type because it cannot be
        // inferred.
        [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
               TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::RngOutput value_state = xla::NormalFloatingPointDistribution(
          key, state, BitGen(alg), xla_shape);
      xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
      return {{normal, value_state.state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulStandardNormalOp);

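// Kernel for StatefulTruncatedNormal: samples from a truncated normal
// distribution by pushing uniform samples in (0, 1) through the inverse-CDF
// based TruncatedNormal helper.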
class StatefulTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler =
        // The lambda needs an explicit return type because it cannot be
        // inferred.
        [builder, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                        TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

      xla::RngOutput uniform_result = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
          xla::One(builder, xla_shape.element_type()));
      xla::XlaOp uniform = uniform_result.value;
      state = uniform_result.state;
      xla::XlaOp truncated_normal = TruncatedNormal(uniform);
      truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
      return {{truncated_normal, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulTruncatedNormalOp);

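// Kernel for StatefulUniformInt: samples integers uniformly from
// [minval, maxval), where minval and maxval are runtime inputs.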
class StatefulUniformIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp minval = ctx->Input(3);
    xla::XlaOp maxval = ctx->Input(4);
    // Named `sampler` (not `sample_with_threefry`) because it handles both
    // ThreeFry and Philox via `alg`.
    auto sampler = [minval, maxval, this](
                       Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                       TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniform(alg, key, state, xla_shape, minval, maxval);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformIntOp);

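// Kernel for StatefulUniformFullInt: samples integers uniformly over the full
// range of the integer dtype.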
class StatefulUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler = [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                          TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniformFullInt(alg, key, state, xla_shape);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformFullIntOp);

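// Advances the RNG counter by `delta`. Philox uses a 128-bit counter spread
// across two U64 words; ThreeFry uses a single U64 counter.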
xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
                           xla::XlaOp delta) {
  // Multiply the delta by 256 to be consistent with the CPU/GPU kernels.
  delta = delta * xla::ConstantR0WithType(delta.builder(), xla::U64, 256);
  if (alg == RNG_ALG_PHILOX) {
    return xla::PhiloxIncreaseCounter(counter, delta);
  } else {
    return counter + delta;
  }
}

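// Right-pads the 1-D value `a` with `n` zeros.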
xla::XlaOp PadRight(xla::XlaOp a, int n) {
  return xla::Pad(a, xla::ScalarLike(a, 0),
                  xla::MakeEdgePaddingConfig({{0, n}}));
}

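// Kernel for RngSkip and RngReadAndSkip: advances the RNG state without
// producing random values. `AlgEnumType` is the element type of the algorithm
// input, and `read_old_value` selects whether the old state is returned as an
// output, right-padded with zeros so its shape is algorithm-independent.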
template <typename AlgEnumType = int64_t, bool read_old_value = false>
class RngSkipOp : public XlaOpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int state_input_idx = 0;
    const int alg_input_idx = 1;
    const int delta_input_idx = 2;
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
                                          &var_shape, &var));
    xla::Literal alg_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
    Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
    OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
    if (read_old_value) {
      auto counter_size = GetCounterSize(alg);
      xla::XlaOp output = var;
      if (RNG_MAX_COUNTER_SIZE > counter_size) {
        // Because the size of `var` depends on the algorithm while we want the
        // output to have a fixed size (to help shape inference), we fix the
        // output size to be the maximal state size among algorithms, and
        // right-pad it with zeros if var's size is smaller than that.
        output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
      }
      ctx->SetOutput(0, output);
    }
    xla::XlaOp counter;
    xla::XlaOp key;
    std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
    xla::XlaOp delta = ctx->Input(delta_input_idx);
    delta = xla::BitcastConvertType(delta, xla::U64);
    auto new_counter = IncreaseCounter(alg, counter, delta);
    var = StateAndKeyToVariable(alg, new_counter, key);
    xla::PrimitiveType state_element_type;
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
    var = xla::BitcastConvertType(var, state_element_type);
    OP_REQUIRES_OK(
        ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
};

REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
                RngSkipOp<>);

using RngReadAndSkipOp = RngSkipOp<int32, true>;

REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
                RngReadAndSkipOp);

}  // namespace
}  // namespace tensorflow