/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

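// Maps the TensorFlow RNG algorithm enum onto XLA's xla::RandomAlgorithm.
// Any unrecognized value falls back to ThreeFry.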
inline xla::RandomAlgorithm TensorFlowRngAlgToXla(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  } else if (alg == RNG_ALG_THREEFRY) {
    return xla::RandomAlgorithm::RNG_THREE_FRY;
  } else if (alg == RNG_ALG_AUTO_SELECT) {
    return xla::RandomAlgorithm::RNG_DEFAULT;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

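// Inverse of TensorFlowRngAlgToXla, with the same ThreeFry fallback.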
inline Algorithm XlaRngAlgToTensorFlow(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  } else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
    return RNG_ALG_THREEFRY;
  } else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
    return RNG_ALG_AUTO_SELECT;
  }
  return RNG_ALG_THREEFRY;
}

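// Slices the counter words out of a state vector laid out as
// [key (RNG_KEY_SIZE words) | counter (algorithm-dependent length)].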
xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = XlaRngAlgToTensorFlow(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

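// Adapts xla::RngBitGenerator to a (key, counter) calling convention: both
// parts are bitcast to u64 and concatenated into one state vector, and the
// updated counter is sliced back out of the returned state.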
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

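// Derives the initial (key, counter) pair from a raw u64 key. On CPU/GPU the
// key is run through xla::ScramblePhiloxKey; on other backends the key is
// kept as-is and paired with a zeroed, maximum-size counter.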
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter);
  }
}

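// Chooses what RNG_ALG_AUTO_SELECT resolves to for a given device type.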
Algorithm DefaultRngAlgForDeviceType(absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return RNG_ALG_PHILOX;
  } else {
    return RNG_ALG_AUTO_SELECT;
  }
}

}  // namespace

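// Emits XLA ops producing uniform values of the shape's element type in the
// half-open interval [minval, maxval), dispatching to the floating-point or
// integer distribution helpers from xla/client/lib/prng. Returns the samples
// together with the updated counter.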
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, S32, S64, U32 and U64 are not "
                  "implemented by StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

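// Produces uniformly distributed bits spanning the full range of the
// requested integer type; signed outputs are just the raw bits bitcast to
// the signed type.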
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

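// Reads the scalar `alg` operand (a compile-time constant), resolves
// RNG_ALG_AUTO_SELECT to the device default, and converts the result to the
// XLA enum.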
Status AlgorithmFromInput(XlaOpKernelContext* ctx, int alg_input_idx,
                          absl::string_view device_type_string,
                          xla::RandomAlgorithm* xla_alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg = Algorithm(alg_literal.Get<int>({}));
  if (alg == RNG_ALG_AUTO_SELECT) {
    alg = DefaultRngAlgForDeviceType(device_type_string);
  }
  *xla_alg = TensorFlowRngAlgToXla(alg);
  return OkStatus();
}

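// Trims the counter operand down to the number of words the chosen algorithm
// actually consumes, so callers may pass an oversized (max-size) counter.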
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  // TODO(wangpeng): We shouldn't rely on
  // `GetCounterSize(XlaRngAlgToTensorFlow(alg))` to decide the size when
  // alg==RNG_DEFAULT, which happens to give the correct answer. We should
  // define a "get counter size" function for xla::RandomAlgorithm, independent
  // of `GetCounterSize`.
  auto real_counter_size = GetCounterSize(XlaRngAlgToTensorFlow(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

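// Kernel for StatelessRandomUniformV2: uniform samples in [0, 1), computed
// in F32/F64 and narrowed to bfloat16 afterwards when requested. Dynamic
// shapes are handled by compiling for the shape's upper bound and then
// setting per-dimension sizes from the runtime `shape` operand.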
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
                            0, &shape, xla::ValueInferenceMode::kUpperBound));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);

    // If the input shape is constant, no need to set dimension sizes.
    // TODO(hinsu): Simplify this once MLIR bridge can handle bounded types.
    TensorShape static_shape;
    Status status = ctx->ConstantInputAsShape(0, &static_shape);
    if (status.ok()) {
      ctx->SetOutput(0, uniform);
      return;
    }

    auto result_or = xla::SetAllDimensionSizes(&ctx->value_inference(), uniform,
                                               ctx->Input(0));
    OP_REQUIRES_OK(ctx, result_or.status());
    ctx->SetOutput(0, result_or.ValueOrDie());
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

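// Kernel for StatelessRandomUniformIntV2: uniform integers in
// [minval, maxval), with both bounds supplied as scalar operands.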
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

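// Kernel for StatelessRandomUniformFullIntV2: uniform bits over the entire
// range of the integer dtype, so no minval/maxval operands are needed.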
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

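// Kernel for StatelessRandomNormalV2: standard normal samples via
// xla::NormalFloatingPointDistribution, with the same bounded-dynamic-shape
// handling as StatelessRandomUniformOp.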
class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
                            0, &shape, xla::ValueInferenceMode::kUpperBound));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);

    // If the input shape is constant, no need to set dimension sizes.
    // TODO(hinsu): Simplify this once MLIR bridge can handle bounded types.
    TensorShape static_shape;
    Status status = ctx->ConstantInputAsShape(0, &static_shape);
    if (status.ok()) {
      ctx->SetOutput(0, normal);
      return;
    }

    auto result_or = xla::SetAllDimensionSizes(&ctx->value_inference(), normal,
                                               ctx->Input(0));
    OP_REQUIRES_OK(ctx, result_or.status());
    ctx->SetOutput(0, result_or.ValueOrDie());
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

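// Kernel for StatelessTruncatedNormalV2: draws uniforms in
// [MinPositiveNormalValue, 1) and maps them through TruncatedNormal. The
// positive lower bound presumably keeps the inverse-CDF transform away from
// its singularity at zero (inference; not stated in this file).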
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(
        ctx, AlgorithmFromInput(ctx, alg_input_idx, device_type_string_, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(XlaRngAlgToTensorFlow(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

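// Kernel for StatelessRandomGetKeyCounterAlg: expands a shape-[2] seed into
// the (key, counter, alg) triple consumed by the V2 sampling ops above.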
class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    auto alg = DefaultRngAlgForDeviceType(device_type_string_);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

// TODO(hinsu): Disallow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

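// Kernel for StatelessRandomGetKeyCounter: the same seed expansion, minus
// the algorithm output (see StatelessRandomGetAlg below).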
class GetKeyCounterOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};

// TODO(hinsu): Disallow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);

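// Kernel for StatelessRandomGetAlg: returns the device's default RNG
// algorithm as a scalar.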
class GetAlgOp : public XlaOpKernel {
 public:
  explicit GetAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto alg = DefaultRngAlgForDeviceType(device_type_string_);
    auto builder = ctx->builder();
    ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

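// XlaRngBitGenerator has no hand-written kernel in this file; it is routed
// through the generic MLIR-based kernel.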
REGISTER_XLA_OP(Name("XlaRngBitGenerator")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}),
                MlirXlaOpKernel);

}  // namespace
}  // namespace tensorflow