• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/stateful_random_ops.h"
17 
18 #include <cmath>
19 #include <utility>
20 
21 #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
22 #include "tensorflow/compiler/tf2xla/lib/random.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/lib/constants.h"
29 #include "tensorflow/compiler/xla/client/lib/math.h"
30 #include "tensorflow/compiler/xla/client/lib/prng.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/rng_alg.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor_shape.h"
37 #include "tensorflow/core/lib/math/math_util.h"
38 
39 namespace tensorflow {
40 namespace {
41 
BitGen(Algorithm alg)42 xla::BitGeneratorTy BitGen(Algorithm alg) {
43   if (alg == RNG_ALG_PHILOX) {
44     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
45       state =
46           xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
47       xla::XlaOp result =
48           xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
49       xla::XlaOp data = xla::GetTupleElement(result, 1);
50       xla::XlaOp new_state =
51           xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
52       return xla::RngOutput{data, new_state};
53     };
54   } else {
55     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
56       state = xla::ConcatScalars(key.builder(), {key, state});
57       xla::XlaOp result = xla::RngBitGenerator(
58           xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
59       xla::XlaOp data = xla::GetTupleElement(result, 1);
60       xla::XlaOp new_state = xla::Reshape(
61           xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
62       return xla::RngOutput{data, new_state};
63     };
64   }
65 }
66 
StatefulRngUniform(Algorithm alg,xla::XlaOp key,xla::XlaOp initial_state,const xla::Shape & shape,xla::XlaOp minval,xla::XlaOp maxval)67 xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,
68                                   xla::XlaOp initial_state,
69                                   const xla::Shape& shape, xla::XlaOp minval,
70                                   xla::XlaOp maxval) {
71   xla::PrimitiveType type = shape.element_type();
72   switch (type) {
73     case xla::F32:
74     case xla::F64:
75       return xla::UniformFloatingPointDistribution(
76           key, initial_state, BitGen(alg), minval, maxval, shape);
77     case xla::U32:
78     case xla::S32:
79     case xla::U64:
80     case xla::S64:
81       return UniformIntDistribution(key, initial_state, BitGen(alg), minval,
82                                     maxval, shape);
83     default:
84       return {key.builder()->ReportError(xla::Unimplemented(
85                   "Types other than F32, U32, S32, U64 and S64 "
86                   "are not implemented by "
87                   "StatefulRngUniform; got %s",
88                   xla::primitive_util::LowercasePrimitiveTypeName(type))),
89               initial_state};
90   }
91 }
92 
StatefulRngUniformFullInt(Algorithm alg,xla::XlaOp key,xla::XlaOp initial_state,const xla::Shape & shape)93 xla::RngOutput StatefulRngUniformFullInt(Algorithm alg, xla::XlaOp key,
94                                          xla::XlaOp initial_state,
95                                          const xla::Shape& shape) {
96   xla::PrimitiveType type = shape.element_type();
97   xla::RngOutput output = BitGen(alg)(key, initial_state, shape);
98   switch (type) {
99     case xla::U32:
100     case xla::U64:
101       return output;
102     case xla::S32:
103     case xla::S64:
104       output.value = BitcastConvertType(output.value, type);
105       return output;
106     default:
107       return {
108           key.builder()->ReportError(xla::Unimplemented(
109               "Types other than U32, S32, U64 and S64 are not implemented by "
110               "StatefulRngUniformFullInt; got: %s",
111               xla::primitive_util::LowercasePrimitiveTypeName(type))),
112           initial_state};
113   }
114 }
115 
116 using SamplerReturnType = StatusOr<xla::RngOutput>;
117 
GetMinStateSize(Algorithm alg)118 int64_t GetMinStateSize(Algorithm alg) {
119   if (alg == RNG_ALG_PHILOX) {
120     return PHILOX_MIN_STATE_SIZE;
121   }
122   return THREEFRY_MIN_STATE_SIZE;
123 }
124 
CheckStateShape(Algorithm alg,const TensorShape & shape)125 Status CheckStateShape(Algorithm alg, const TensorShape& shape) {
126   if (shape.dims() != 1) {
127     return errors::InvalidArgument(
128         "RNG state must have one and only one dimension, not ", shape.dims());
129   }
130   auto state_size = shape.dim_size(0);
131   auto min_state_size = GetMinStateSize(alg);
132   if (state_size < min_state_size) {
133     return errors::InvalidArgument("The size of the state must be at least ",
134                                    min_state_size, "; got ", state_size);
135   }
136   return OkStatus();
137 }
138 
StateAndKeyFromVariable(Algorithm alg,xla::XlaOp var)139 std::pair<xla::XlaOp, xla::XlaOp> StateAndKeyFromVariable(Algorithm alg,
140                                                           xla::XlaOp var) {
141   if (alg == RNG_ALG_THREEFRY) {
142     static constexpr int kStateSize = 1;
143     auto state = BitcastConvertType(
144         xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
145     auto key = BitcastConvertType(
146         xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
147         xla::U64);
148     return std::make_pair(state, key);
149   } else {
150     static constexpr int kStateSize = 2;
151     auto state =
152         BitcastConvertType(xla::Slice(var, {0}, {kStateSize}, {1}), xla::U64);
153     auto key = xla::Reshape(
154         BitcastConvertType(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}),
155                            xla::U64),
156         {});
157     return std::make_pair(state, key);
158   }
159 }
160 
StateAndKeyToVariable(Algorithm alg,xla::XlaOp state,xla::XlaOp key)161 xla::XlaOp StateAndKeyToVariable(Algorithm alg, xla::XlaOp state,
162                                  xla::XlaOp key) {
163   auto builder = state.builder();
164   if (alg == RNG_ALG_THREEFRY) {
165     return ConcatScalars(builder, {state, key});
166   } else {
167     return ConcatInDim(builder, {state, xla::Reshape(key, {1})}, 0);
168   }
169 }
170 
171 // A helper function containing the common part of several kernels below.
172 // Precondition: 'algorithm' and 'shape' are compile-time constants.
CompileImpl(XlaOpKernelContext * ctx,int state_input_idx,int alg_input_idx,int shape_input_idx,std::function<SamplerReturnType (Algorithm,xla::XlaOp,xla::XlaOp,TensorShape)> const & sampler)173 Status CompileImpl(
174     XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
175     int shape_input_idx,
176     std::function<SamplerReturnType(Algorithm, xla::XlaOp, xla::XlaOp,
177                                     TensorShape)> const& sampler) {
178   auto alg_shape = ctx->InputShape(alg_input_idx);
179   if (alg_shape.dims() != 0) {
180     return errors::InvalidArgument("algorithm must be of shape [], not ",
181                                    alg_shape.DebugString());
182   }
183   xla::Literal alg_literal;
184   TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
185   Algorithm alg = Algorithm(alg_literal.Get<int64_t>({}));
186   if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
187     return errors::InvalidArgument("Unsupported algorithm id: ", alg);
188   }
189 
190   xla::XlaOp var;
191   TensorShape var_shape;
192   TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
193       state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
194   TF_RETURN_IF_ERROR(CheckStateShape(alg, var_shape));
195   TensorShape shape;
196   TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
197   xla::XlaOp state;
198   xla::XlaOp key;
199   std::tie(state, key) = StateAndKeyFromVariable(alg, var);
200   auto status_or_value = sampler(alg, state, key, shape);
201   if (!status_or_value.ok()) {
202     return status_or_value.status();
203   }
204   xla::RngOutput value_state = std::move(status_or_value).value();
205   state = value_state.state;
206   ctx->SetOutput(0, value_state.value);
207   var = StateAndKeyToVariable(alg, state, key);
208   xla::PrimitiveType state_element_type;
209   TF_RETURN_IF_ERROR(
210       DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
211   var = BitcastConvertType(var, state_element_type);
212   TF_RETURN_IF_ERROR(
213       ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
214   return OkStatus();
215 }
216 
217 class StatefulUniformOp : public XlaOpKernel {
218  public:
StatefulUniformOp(OpKernelConstruction * ctx)219   explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
220     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
221   }
222 
Compile(XlaOpKernelContext * ctx)223   void Compile(XlaOpKernelContext* ctx) override {
224     xla::XlaBuilder* builder = ctx->builder();
225     auto sampler = [builder, this](Algorithm alg, xla::XlaOp state,
226                                    xla::XlaOp key,
227                                    TensorShape shape) -> SamplerReturnType {
228       xla::Shape xla_shape;
229       DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
230       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
231       xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
232       xla::RngOutput uniform_state = StatefulRngUniform(
233           alg, key, state, xla_shape,
234           xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
235           xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
236       xla::XlaOp uniform = uniform_state.value;
237       state = uniform_state.state;
238       uniform = MaybeConvertF32ToBF16(uniform, dtype_);
239       return {{uniform, state}};
240     };
241     OP_REQUIRES_OK(ctx,
242                    CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
243                                /*shape_input_idx=*/2, sampler));
244   }
245 
246  private:
247   DataType dtype_;
248 
249   TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
250 };
251 
252 // TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
253 REGISTER_XLA_OP(Name("StatefulUniform")
254                     .CompileTimeConstantInput("algorithm")
255                     .CompileTimeConstantInput("shape")
256                     .TypeConstraint("dtype",
257                                     {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
258                 StatefulUniformOp);
259 
260 class StatefulStandardNormalOp : public XlaOpKernel {
261  public:
StatefulStandardNormalOp(OpKernelConstruction * ctx)262   explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
263       : XlaOpKernel(ctx) {
264     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
265   }
266 
Compile(XlaOpKernelContext * ctx)267   void Compile(XlaOpKernelContext* ctx) override {
268     auto sampler =
269         // Needs explicit lambda return type because it fails to be inferred.
270         [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
271                TensorShape shape) -> SamplerReturnType {
272       xla::Shape xla_shape;
273       DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
274       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
275       xla::RngOutput value_state = xla::NormalFloatingPointDistribution(
276           key, state, BitGen(alg), xla_shape);
277       xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
278       return {{normal, value_state.state}};
279     };
280     OP_REQUIRES_OK(ctx,
281                    CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
282                                /*shape_input_idx=*/2, sampler));
283   }
284 
285  private:
286   DataType dtype_;
287 
288   TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
289 };
290 
291 // TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
292 REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
293                     .CompileTimeConstantInput("algorithm")
294                     .CompileTimeConstantInput("shape")
295                     .TypeConstraint("dtype",
296                                     {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
297                 StatefulStandardNormalOp);
298 
299 class StatefulTruncatedNormalOp : public XlaOpKernel {
300  public:
StatefulTruncatedNormalOp(OpKernelConstruction * ctx)301   explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
302       : XlaOpKernel(ctx) {
303     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
304   }
305 
Compile(XlaOpKernelContext * ctx)306   void Compile(XlaOpKernelContext* ctx) override {
307     xla::XlaBuilder* builder = ctx->builder();
308     auto sampler =
309         // Needs explicit lambda return type because it fails to be inferred.
310         [builder, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
311                         TensorShape shape) -> SamplerReturnType {
312       xla::Shape xla_shape;
313       DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
314       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
315 
316       xla::RngOutput uniform_result = StatefulRngUniform(
317           alg, key, state, xla_shape,
318           xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
319           xla::One(builder, xla_shape.element_type()));
320       xla::XlaOp uniform = uniform_result.value;
321       state = uniform_result.state;
322       xla::XlaOp truncated_normal = TruncatedNormal(uniform);
323       truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
324       return {{truncated_normal, state}};
325     };
326     OP_REQUIRES_OK(ctx,
327                    CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
328                                /*shape_input_idx=*/2, sampler));
329   }
330 
331  private:
332   DataType dtype_;
333 
334   TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
335 };
336 
337 // TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
338 REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
339                     .CompileTimeConstantInput("algorithm")
340                     .CompileTimeConstantInput("shape")
341                     .TypeConstraint("dtype",
342                                     {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
343                 StatefulTruncatedNormalOp);
344 
345 class StatefulUniformIntOp : public XlaOpKernel {
346  public:
StatefulUniformIntOp(OpKernelConstruction * ctx)347   explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
348     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
349   }
350 
Compile(XlaOpKernelContext * ctx)351   void Compile(XlaOpKernelContext* ctx) override {
352     xla::XlaOp minval = ctx->Input(3);
353     xla::XlaOp maxval = ctx->Input(4);
354     auto sample_with_threefry =
355         [minval, maxval, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
356                                TensorShape shape) -> SamplerReturnType {
357       xla::Shape xla_shape;
358       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
359       return StatefulRngUniform(alg, key, state, xla_shape, minval, maxval);
360     };
361     OP_REQUIRES_OK(ctx,
362                    CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
363                                /*shape_input_idx=*/2, sample_with_threefry));
364   }
365 
366  private:
367   DataType dtype_;
368 
369   TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
370 };
371 
372 REGISTER_XLA_OP(Name("StatefulUniformInt")
373                     .CompileTimeConstantInput("algorithm")
374                     .CompileTimeConstantInput("shape")
375                     .TypeConstraint("dtype",
376                                     {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
377                 StatefulUniformIntOp);
378 
379 class StatefulUniformFullIntOp : public XlaOpKernel {
380  public:
StatefulUniformFullIntOp(OpKernelConstruction * ctx)381   explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
382       : XlaOpKernel(ctx) {
383     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
384   }
385 
Compile(XlaOpKernelContext * ctx)386   void Compile(XlaOpKernelContext* ctx) override {
387     auto sample_with_threefry = [this](Algorithm alg, xla::XlaOp state,
388                                        xla::XlaOp key,
389                                        TensorShape shape) -> SamplerReturnType {
390       xla::Shape xla_shape;
391       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
392       return StatefulRngUniformFullInt(alg, key, state, xla_shape);
393     };
394     OP_REQUIRES_OK(ctx,
395                    CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
396                                /*shape_input_idx=*/2, sample_with_threefry));
397   }
398 
399  private:
400   DataType dtype_;
401 
402   TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
403 };
404 
405 REGISTER_XLA_OP(Name("StatefulUniformFullInt")
406                     .CompileTimeConstantInput("algorithm")
407                     .CompileTimeConstantInput("shape")
408                     .TypeConstraint("dtype",
409                                     {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
410                 StatefulUniformFullIntOp);
411 
IncreaseCounter(Algorithm const & alg,xla::XlaOp counter,xla::XlaOp delta)412 xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
413                            xla::XlaOp delta) {
414   // Multiplying 256 to be consistent with the CPU/GPU kernels
415   delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
416   if (alg == RNG_ALG_PHILOX) {
417     return xla::PhiloxIncreaseCounter(counter, delta);
418   } else {
419     return counter + delta;
420   }
421 }
422 
PadRight(xla::XlaOp a,int n)423 xla::XlaOp PadRight(xla::XlaOp a, int n) {
424   return xla::Pad(a, xla::ScalarLike(a, 0),
425                   xla::MakeEdgePaddingConfig({{0, n}}));
426 }
427 
428 template <typename AlgEnumType = int64_t, bool read_old_value = false>
429 class RngSkipOp : public XlaOpKernel {
430  public:
RngSkipOp(OpKernelConstruction * ctx)431   explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
432 
Compile(XlaOpKernelContext * ctx)433   void Compile(XlaOpKernelContext* ctx) override {
434     const int state_input_idx = 0;
435     const int alg_input_idx = 1;
436     const int delta_input_idx = 2;
437     xla::XlaOp var;
438     TensorShape var_shape;
439     OP_REQUIRES_OK(ctx,
440                    ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
441                                           &var_shape, &var));
442     xla::Literal alg_literal;
443     OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
444     Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
445     OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
446                 errors::InvalidArgument("Unsupported algorithm id: ", alg));
447     OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
448     if (read_old_value) {
449       auto counter_size = GetCounterSize(alg);
450       xla::XlaOp output = var;
451       if (RNG_MAX_COUNTER_SIZE > counter_size) {
452         // Because the size of `var` depends on the algorithm while we want the
453         // output to have a fixed size (to help shape inference), we fix the
454         // output size to be the maximal state size among algorithms, and right-
455         // pad it with zeros if var's size is smaller than that.
456         output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
457       }
458       ctx->SetOutput(0, output);
459     }
460     xla::XlaOp counter;
461     xla::XlaOp key;
462     std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
463     xla::XlaOp delta = ctx->Input(delta_input_idx);
464     delta = BitcastConvertType(delta, xla::U64);
465     auto new_counter = IncreaseCounter(alg, counter, delta);
466     var = StateAndKeyToVariable(alg, new_counter, key);
467     xla::PrimitiveType state_element_type;
468     OP_REQUIRES_OK(
469         ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
470     var = BitcastConvertType(var, state_element_type);
471     OP_REQUIRES_OK(
472         ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
473   }
474 
475  private:
476   TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
477 };
478 
479 REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
480                 RngSkipOp<>);
481 
482 using RngReadAndSkipOp = RngSkipOp<int32, true>;
483 
484 REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
485                 RngReadAndSkipOp);
486 
487 }  // namespace
488 }  // namespace tensorflow
489