/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  } else if (alg == RNG_ALG_THREEFRY) {
    return xla::RandomAlgorithm::RNG_THREE_FRY;
  } else if (alg == RNG_ALG_XLA_DEFAULT) {
    return xla::RandomAlgorithm::RNG_DEFAULT;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  } else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
    return RNG_ALG_THREEFRY;
  } else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
    return RNG_ALG_XLA_DEFAULT;
  }
  return RNG_ALG_THREEFRY;
}

xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

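// Concatenates the key and counter into a single RNG state vector, invokes
// xla::RngBitGenerator, and returns the generated bits together with the
// updated counter (bitcast back to S64) as an xla::RngOutput.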
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

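// Derives a (key, counter) pair from the given key. On CPU and GPU the key
// is scrambled via xla::ScramblePhiloxKey; on other backends the key is
// passed through unchanged with a zero counter of maximal size.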
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter);
  }
}

Algorithm GetAlg(absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return RNG_ALG_PHILOX;
  } else {
    return RNG_ALG_XLA_DEFAULT;
  }
}

}  // namespace

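// Generates uniformly distributed values in [minval, maxval) with the element
// type carried by `shape`. A minimal usage sketch (assuming `alg`, `key`,
// `counter`, `xla_shape` and `builder` are already in scope):
//
//   auto result = StatelessRngUniformV2(
//       alg, key, counter, xla_shape,
//       xla::ConstantR0WithType(builder, xla::F32, 0.0),
//       xla::ConstantR0WithType(builder, xla::F32, 1.0));
//   // result.value holds the samples; result.state the advanced counter.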
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, S32, S64, U32 and U64 are not "
                  "implemented by StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

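// Generates bits covering the full range of the integral element type: for
// unsigned types the raw generated bits are returned directly, for signed
// types they are bitcast to the signed type.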
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

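// Reads the compile-time-constant scalar `alg` input at `alg_input_idx` and
// converts it to the corresponding xla::RandomAlgorithm.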
Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
                    xla::RandomAlgorithm* alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg_ = Algorithm(alg_literal.Get<int>({}));
  *alg = AlgorithmToRandomAlgorithm(alg_);
  return Status::OK();
}

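// Trims the input counter down to the size the chosen algorithm actually
// consumes, e.g. when a Philox-sized counter is fed to ThreeFry.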
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

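// Kernel for StatelessRandomUniformV2: samples uniformly from [0, 1) in F32
// or F64, converting the result to BF16 afterwards when requested.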
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

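// Kernel for StatelessRandomUniformIntV2: samples uniformly from the
// caller-supplied scalar range [minval, maxval).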
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

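// Kernel for StatelessRandomUniformFullIntV2: samples uniformly over the
// full range of the integer dtype.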
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

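// Kernel for StatelessRandomNormalV2: samples from the standard normal
// distribution via xla::NormalFloatingPointDistribution, converting to BF16
// afterwards when requested.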
class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

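// Kernel for StatelessTruncatedNormalV2: draws uniform samples bounded away
// from zero and feeds them into TruncatedNormal, which maps them onto a
// standard normal distribution truncated to two standard deviations.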
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

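// Kernel for StatelessRandomGetKeyCounterAlg: expands a shape-[2] seed into
// a (key, counter, algorithm) triple usable by the V2 stateless ops.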
class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    auto alg = GetAlg(device_type_string_);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

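// Kernel for StatelessRandomGetKeyCounter: like the op above, but derives
// only the (key, counter) pair and leaves algorithm selection to the caller.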
class GetKeyCounterOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    auto key_counter = GetKeyCounter(device_type_string_, key);
    key = std::get<0>(key_counter);
    auto counter = std::get<1>(key_counter);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};

// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);

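// Kernel for StatelessRandomGetAlg: reports which RNG algorithm this backend
// prefers (Philox on CPU/GPU, the XLA default elsewhere).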
class GetAlgOp : public XlaOpKernel {
 public:
  explicit GetAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto alg = GetAlg(device_type_string_);
    auto builder = ctx->builder();
    ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

}  // namespace
}  // namespace tensorflow