/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

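// Maps a TensorFlow RNG algorithm enum to the corresponding XLA algorithm.
// Unrecognized values fall back to ThreeFry.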
inline xla::RandomAlgorithm TensorFlowRngAlgToXla(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  } else if (alg == RNG_ALG_THREEFRY) {
    return xla::RandomAlgorithm::RNG_THREE_FRY;
  } else if (alg == RNG_ALG_AUTO_SELECT) {
    return xla::RandomAlgorithm::RNG_DEFAULT;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

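// Inverse of TensorFlowRngAlgToXla. Unrecognized values fall back to ThreeFry.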
inline Algorithm XlaRngAlgToTensorFlow(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  } else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
    return RNG_ALG_THREEFRY;
  } else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
    return RNG_ALG_AUTO_SELECT;
  }
  return RNG_ALG_THREEFRY;
}

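// Slices the counter portion out of a combined [key, counter] state vector.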
xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = XlaRngAlgToTensorFlow(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

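// Concatenates the key and counter into an RNG state, runs XLA's
// RngBitGenerator, and returns the generated bits together with the updated
// counter.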
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

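// Derives a (key, counter) pair from a seed-derived key. On the CPU and GPU
// backends the key is scrambled for Philox; on other backends the key is used
// as-is with a zero counter.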
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter);
  }
}

Algorithm DefaultRngAlgForDeviceType(absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return RNG_ALG_PHILOX;
  } else {
    return RNG_ALG_AUTO_SELECT;
  }
}

}  // namespace

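// Emits XLA ops that sample uniformly from [minval, maxval) using the given
// key, counter, and algorithm. Supports floating-point and integer element
// types.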
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, S32, S64, U32 and U64 are not "
                  "implemented by StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

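// Emits XLA ops that sample uniformly over the full range of the requested
// integer type.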
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

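// Reads the scalar `alg` input as a compile-time constant, resolves
// RNG_ALG_AUTO_SELECT to the device default, and converts it to the XLA
// algorithm.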
Status AlgorithmFromInput(XlaOpKernelContext* ctx, int alg_input_idx,
                          absl::string_view device_type_string,
                          xla::RandomAlgorithm* xla_alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg = Algorithm(alg_literal.Get<int>({}));
  if (alg == RNG_ALG_AUTO_SELECT) {
    alg = DefaultRngAlgForDeviceType(device_type_string);
  }
  *xla_alg = TensorFlowRngAlgToXla(alg);
  return OkStatus();
}

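// Slices the input counter down to the size required by `alg` when the caller
// provides a larger one.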
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  // TODO(wangpeng): We shouldn't rely on
  // `GetCounterSize(XlaRngAlgToTensorFlow(alg))` to decide the size when
  // alg==RNG_DEFAULT, which happens to give the correct answer. We should
  // define a "get counter size" function for xla::RandomAlgorithm, independent
  // of `GetCounterSize`.
  auto real_counter_size = GetCounterSize(XlaRngAlgToTensorFlow(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

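// Kernel for StatelessRandomUniformV2: uniformly distributed floating-point
// samples in [0, 1).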
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
                            0, &shape, xla::ValueInferenceMode::kUpperBound));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);

    // If the input shape is constant, no need to set dimension sizes.
    // TODO(hinsu): Simplify this once MLIR bridge can handle bounded types.
    TensorShape static_shape;
    Status status = ctx->ConstantInputAsShape(0, &static_shape);
    if (status.ok()) {
      ctx->SetOutput(0, uniform);
      return;
    }

    auto result_or = xla::SetAllDimensionSizes(&ctx->value_inference(), uniform,
                                               ctx->Input(0));
    OP_REQUIRES_OK(ctx, result_or.status());
    ctx->SetOutput(0, result_or.ValueOrDie());
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

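// Kernel for StatelessRandomUniformIntV2: uniformly distributed integer
// samples in [minval, maxval).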
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

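// Kernel for StatelessRandomUniformFullIntV2: integer samples uniformly
// distributed over the full range of the output dtype.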
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

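// Kernel for StatelessRandomNormalV2: samples from a standard normal
// distribution (mean 0, standard deviation 1).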
class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
                            0, &shape, xla::ValueInferenceMode::kUpperBound));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);

    // If the input shape is constant, no need to set dimension sizes.
    // TODO(hinsu): Simplify this once MLIR bridge can handle bounded types.
    TensorShape static_shape;
    Status status = ctx->ConstantInputAsShape(0, &static_shape);
    if (status.ok()) {
      ctx->SetOutput(0, normal);
      return;
    }

    auto result_or = xla::SetAllDimensionSizes(&ctx->value_inference(), normal,
                                               ctx->Input(0));
    OP_REQUIRES_OK(ctx, result_or.status());
    ctx->SetOutput(0, result_or.ValueOrDie());
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

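// Kernel for StatelessTruncatedNormalV2: standard normal samples truncated to
// lie within two standard deviations of the mean.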
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

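// Kernel for StatelessRandomGetKeyCounterAlg: derives an RNG key and counter
// from a shape-[2] seed and also outputs the default algorithm for this
// device.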
class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    auto alg = DefaultRngAlgForDeviceType(device_type_string_);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

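// Kernel for StatelessRandomGetKeyCounter: like GetKeyCounterAlgOp, but
// without the algorithm output.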
class GetKeyCounterOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};

// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);

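// Kernel for StatelessRandomGetAlg: returns the default RNG algorithm for this
// device.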
class GetAlgOp : public XlaOpKernel {
 public:
  explicit GetAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto alg = DefaultRngAlgForDeviceType(device_type_string_);
    auto builder = ctx->builder();
    ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

REGISTER_XLA_OP(Name("XlaRngBitGenerator")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}),
                MlirXlaOpKernel);

}  // namespace
}  // namespace tensorflow