/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

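// Maps TensorFlow's RNG algorithm enum (rng_alg.h) to XLA's RandomAlgorithm,
// falling back to ThreeFry for unrecognized values.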
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  } else if (alg == RNG_ALG_THREEFRY) {
    return xla::RandomAlgorithm::RNG_THREE_FRY;
  } else if (alg == RNG_ALG_XLA_DEFAULT) {
    return xla::RandomAlgorithm::RNG_DEFAULT;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

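// Inverse of AlgorithmToRandomAlgorithm, again falling back to ThreeFry for
// unrecognized values.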
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  } else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
    return RNG_ALG_THREEFRY;
  } else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
    return RNG_ALG_XLA_DEFAULT;
  }
  return RNG_ALG_THREEFRY;
}

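// Slices the counter out of a packed [key, counter] RNG state vector; the key
// occupies the first RNG_KEY_SIZE elements and the counter size depends on
// the algorithm.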
xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

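// Packs key and counter into a single u64 state vector, runs XLA's
// RngBitGenerator on it, and returns the generated bits together with the
// updated counter (sliced out of the new state and bitcast back to s64).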
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

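// Derives the (key, counter) pair for the backend's preferred algorithm: on
// CPU/GPU the key is scrambled for Philox, elsewhere the key passes through
// unchanged with a maximal-size zero counter.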
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter);
  }
}

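// Chooses the RNG algorithm to advertise for the given XLA backend.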
Algorithm GetAlg(absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return RNG_ALG_PHILOX;
  } else {
    return RNG_ALG_XLA_DEFAULT;
  }
}

}  // namespace

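// Generates uniform values in [minval, maxval) of the requested element type,
// floating-point for F32/F64 and integer for S32/S64/U32/U64. Defined outside
// the anonymous namespace so other translation units can use it.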
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, S32, S64, U32 and U64 are not "
                  "implemented by StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

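// Generates uniformly distributed bits covering the full range of the integer
// element type; for S32/S64 the raw generated bits are bitcast to the signed
// type.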
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

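// Reads the compile-time-constant `alg` scalar input and converts it to XLA's
// RandomAlgorithm enum.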
Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
                    xla::RandomAlgorithm* alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg_ = Algorithm(alg_literal.Get<int>({}));
  *alg = AlgorithmToRandomAlgorithm(alg_);
  return Status::OK();
}

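// Trims the incoming counter to the size the chosen algorithm actually
// consumes; callers may pass a counter larger than required.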
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

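// Kernel for StatelessRandomUniformV2: uniform floats in [0, 1). BF16 outputs
// are generated in F32 and converted afterwards.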
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

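// Kernel for StatelessRandomUniformIntV2: uniform integers in
// [minval, maxval), with scalar minval/maxval taken from inputs 4 and 5.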
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

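// Kernel for StatelessRandomUniformFullIntV2: uniform samples over the full
// range of the integer dtype.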
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

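// Kernel for StatelessRandomNormalV2: standard normal samples produced by
// xla::NormalFloatingPointDistribution.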
class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

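// Kernel for StatelessTruncatedNormalV2: draws uniforms in
// [min positive normal, 1) so the inverse-CDF transform never sees 0, then
// maps them through the TruncatedNormal helper from tf2xla/lib/random.h.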
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

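// Kernel for StatelessRandomGetKeyCounterAlg: expands a shape-[2] seed into
// the (key, counter, alg) triple consumed by the V2 stateless ops above.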
class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    auto alg = GetAlg(device_type_string_);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

// TODO(hinsu): Disallow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

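// Kernel for StatelessRandomGetKeyCounter: like GetKeyCounterAlgOp, but emits
// only the key and counter; the algorithm is queried separately via
// StatelessRandomGetAlg.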
class GetKeyCounterOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};

// TODO(hinsu): Disallow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);

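// Kernel for StatelessRandomGetAlg: returns the backend's preferred RNG
// algorithm as a scalar.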
class GetAlgOp : public XlaOpKernel {
 public:
  explicit GetAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto alg = GetAlg(device_type_string_);
    auto builder = ctx->builder();
    ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

}  // namespace
}  // namespace tensorflow