1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17
18 #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
19 #include "tensorflow/compiler/tf2xla/lib/random.h"
20 #include "tensorflow/compiler/tf2xla/shape_util.h"
21 #include "tensorflow/compiler/tf2xla/type_util.h"
22 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/compiler/xla/client/lib/constants.h"
26 #include "tensorflow/compiler/xla/client/lib/math.h"
27 #include "tensorflow/compiler/xla/client/lib/prng.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor_shape.h"
33 #include "tensorflow/core/lib/math/math_util.h"
34
35 namespace tensorflow {
36
37 namespace {
38
// Returns a bit-generator callback (key, state, shape) -> xla::RngOutput for
// the given XLA device type string.  CPU and GPU backends get an explicit
// Philox generator; all other backends use the backend-default algorithm.
xla::BitGeneratorTy GetBitGeneratorForDevice(
    absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      // Scramble the user-provided key into the (state, key) pair Philox
      // expects.
      std::tie(state, key) = xla::ScramblePhiloxKey(key);
      // Philox state operand layout: key followed by counter state, along
      // dimension 0.
      xla::XlaOp philox_state =
          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
      xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
                                               philox_state, shape);
      // RngBitGenerator returns a tuple of (new state, generated values).
      return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                            /*state=*/xla::GetTupleElement(result, 0)};
    };
  }
  return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
    // Fold the key into the state vector and let the backend choose its
    // default generator algorithm.
    state = xla::ConcatScalars(key.builder(), {key, state});
    xla::XlaOp result =
        xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
    return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                          /*state=*/xla::GetTupleElement(result, 0)};
  };
}
63
64 } // namespace
65
MaybeConvertF32ToBF16(xla::XlaOp input,DataType dtype)66 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
67 if (dtype == DT_BFLOAT16) {
68 xla::XlaBuilder* builder = input.builder();
69 xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
70 xla::ConstantR0<uint32>(builder, 0xFFFF0000);
71 return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
72 xla::BF16);
73 } else {
74 return input;
75 }
76 }
77
StatelessRngUniform(absl::string_view device_type_string,xla::XlaOp seeds,const xla::Shape & shape,xla::XlaOp minval,xla::XlaOp maxval)78 xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
79 xla::XlaOp seeds, const xla::Shape& shape,
80 xla::XlaOp minval, xla::XlaOp maxval) {
81 xla::XlaBuilder* builder = seeds.builder();
82
83 xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
84 xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
85 xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
86 xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
87 xla::PrimitiveType type = shape.element_type();
88 switch (type) {
89 case xla::F32:
90 case xla::F64:
91 return xla::UniformFloatingPointDistribution(
92 key, initial_state,
93 GetBitGeneratorForDevice(device_type_string), minval, maxval,
94 shape)
95 .value;
96 case xla::S32: // fall through
97 case xla::S64:
98 return UniformIntDistribution(
99 key, initial_state,
100 GetBitGeneratorForDevice(device_type_string), minval, maxval,
101 shape)
102 .value;
103 break;
104 default:
105 return builder->ReportError(xla::Unimplemented(
106 "Types other than F32, S32 and S64 are not implemented by "
107 "StatelessRngUniform; got %s",
108 xla::primitive_util::LowercasePrimitiveTypeName(type)));
109 }
110 }
111
112 namespace {
113
// Builds an XLA graph producing integers uniformly distributed over the FULL
// range of `shape`'s element type (U32, U64, S32 or S64).  `seeds` is a 1-D
// S32 tensor of shape [2]; unsupported element types report an Unimplemented
// error on the builder.
xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
                                      xla::XlaOp seeds,
                                      const xla::Shape& shape) {
  xla::XlaBuilder* builder = seeds.builder();

  // Combine the two S32 seeds into a single U64 key; the counter state
  // starts at zero.
  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
  xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
  xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output =
      GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      // Unsigned outputs can use the raw generator bits directly.
      return output.value;
    case xla::S32:
    case xla::S64:
      // Reinterpret the unsigned bits as signed without changing them.
      return BitcastConvertType(output.value, type);
    default:
      return builder->ReportError(xla::Unimplemented(
          "Types other than U32, S32, U64 and S64 are not implemented by "
          "StatelessRngUniformFullInt; got: %s",
          xla::primitive_util::LowercasePrimitiveTypeName(type)));
  }
}
140
141 class StatelessRandomUniformOp : public XlaOpKernel {
142 public:
StatelessRandomUniformOp(OpKernelConstruction * ctx)143 explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
144 : XlaOpKernel(ctx),
145 device_type_string_(ctx->device_type().type_string()) {
146 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
147 }
148
Compile(XlaOpKernelContext * ctx)149 void Compile(XlaOpKernelContext* ctx) override {
150 xla::XlaBuilder* builder = ctx->builder();
151
152 TensorShape shape;
153 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
154
155 TensorShape seed_shape = ctx->InputShape(1);
156 OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
157 errors::InvalidArgument("seed must have shape [2], not ",
158 seed_shape.DebugString()));
159 xla::XlaOp seed = ctx->Input(1);
160
161 DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
162 xla::Shape xla_shape;
163 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
164 xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
165
166 xla::XlaOp uniform = StatelessRngUniform(
167 device_type_string_, seed, xla_shape,
168 xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
169 xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
170 uniform = MaybeConvertF32ToBF16(uniform, dtype_);
171 ctx->SetOutput(0, uniform);
172 }
173
174 private:
175 DataType dtype_;
176 string device_type_string_;
177
178 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
179 };
180
181 // TODO(phawkins): generalize to non-float, non-int32 seed types.
182 REGISTER_XLA_OP(Name("StatelessRandomUniform")
183 .CompileTimeConstantInput("shape")
184 .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
185 .TypeConstraint("Tseed", DT_INT32),
186 StatelessRandomUniformOp);
187
188 class StatelessRandomUniformIntOp : public XlaOpKernel {
189 public:
StatelessRandomUniformIntOp(OpKernelConstruction * ctx)190 explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
191 : XlaOpKernel(ctx),
192 device_type_string_(ctx->device_type().type_string()) {
193 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
194 }
195
Compile(XlaOpKernelContext * ctx)196 void Compile(XlaOpKernelContext* ctx) override {
197 TensorShape shape;
198 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
199
200 TensorShape seed_shape = ctx->InputShape(1);
201 OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
202 errors::InvalidArgument("seed must have shape [2], not ",
203 seed_shape.DebugString()));
204 TensorShape minval_shape = ctx->InputShape(2);
205 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
206 errors::InvalidArgument("minval must be scalar, got shape ",
207 minval_shape.DebugString()));
208 TensorShape maxval_shape = ctx->InputShape(3);
209 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
210 errors::InvalidArgument("minval must be scalar, got shape ",
211 maxval_shape.DebugString()));
212
213 xla::XlaOp seed = ctx->Input(1);
214 xla::XlaOp minval = ctx->Input(2);
215 xla::XlaOp maxval = ctx->Input(3);
216
217 xla::Shape xla_shape;
218 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
219 xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed,
220 xla_shape, minval, maxval);
221
222 ctx->SetOutput(0, uniform);
223 }
224
225 private:
226 DataType dtype_;
227 string device_type_string_;
228
229 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
230 };
231
232 // TODO(phawkins): generalize to non-int32 seed types.
233 REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
234 .CompileTimeConstantInput("shape")
235 .TypeConstraint("dtype", {DT_INT32, DT_INT64})
236 .TypeConstraint("Tseed", DT_INT32),
237 StatelessRandomUniformIntOp);
238
239 class StatelessRandomUniformFullIntOp : public XlaOpKernel {
240 public:
StatelessRandomUniformFullIntOp(OpKernelConstruction * ctx)241 explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
242 : XlaOpKernel(ctx),
243 device_type_string_(ctx->device_type().type_string()) {
244 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
245 }
246
Compile(XlaOpKernelContext * ctx)247 void Compile(XlaOpKernelContext* ctx) override {
248 TensorShape shape;
249 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
250
251 TensorShape seed_shape = ctx->InputShape(1);
252 OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
253 errors::InvalidArgument("seed must have shape [2], not ",
254 seed_shape.DebugString()));
255
256 xla::XlaOp seed = ctx->Input(1);
257
258 xla::Shape xla_shape;
259 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
260 xla::XlaOp uniform =
261 StatelessRngUniformFullInt(device_type_string_, seed, xla_shape);
262
263 ctx->SetOutput(0, uniform);
264 }
265
266 private:
267 DataType dtype_;
268 string device_type_string_;
269
270 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
271 };
272
273 // TODO(phawkins): generalize to non-int32 seed types.
274 REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt")
275 .CompileTimeConstantInput("shape")
276 .TypeConstraint("dtype", {DT_INT32, DT_INT64})
277 .TypeConstraint("Tseed", DT_INT32),
278 StatelessRandomUniformFullIntOp);
279
280 class StatelessRandomNormalOp : public XlaOpKernel {
281 public:
StatelessRandomNormalOp(OpKernelConstruction * ctx)282 explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
283 : XlaOpKernel(ctx),
284 device_type_string_(ctx->device_type().type_string()) {
285 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
286 }
287
Compile(XlaOpKernelContext * ctx)288 void Compile(XlaOpKernelContext* ctx) override {
289 TensorShape shape;
290 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
291
292 TensorShape seed_shape = ctx->InputShape(1);
293 OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
294 errors::InvalidArgument("seed must have shape [2], not ",
295 seed_shape.DebugString()));
296 xla::XlaOp seed = ctx->Input(1);
297 xla::Shape xla_shape;
298
299 DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
300 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
301
302 xla::XlaBuilder* builder = seed.builder();
303 xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
304 xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
305 xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
306
307 xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
308 xla::XlaOp normal =
309 xla::NormalFloatingPointDistribution(
310 key, initial_state, GetBitGeneratorForDevice(device_type_string_),
311 xla_shape)
312 .value;
313 normal = MaybeConvertF32ToBF16(normal, dtype_);
314 ctx->SetOutput(0, normal);
315 }
316
317 private:
318 DataType dtype_;
319 string device_type_string_;
320
321 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
322 };
323
324 // TODO(phawkins): generalize to non-float, non-int32 seed types.
325 REGISTER_XLA_OP(Name("StatelessRandomNormal")
326 .CompileTimeConstantInput("shape")
327 .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
328 .TypeConstraint("Tseed", DT_INT32),
329 StatelessRandomNormalOp);
330
331 class StatelessTruncatedNormalOp : public XlaOpKernel {
332 public:
StatelessTruncatedNormalOp(OpKernelConstruction * ctx)333 explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
334 : XlaOpKernel(ctx),
335 device_type_string_(ctx->device_type().type_string()) {
336 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
337 }
338
Compile(XlaOpKernelContext * ctx)339 void Compile(XlaOpKernelContext* ctx) override {
340 TensorShape shape;
341 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
342
343 TensorShape seed_shape = ctx->InputShape(1);
344 OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
345 errors::InvalidArgument("seed must have shape [2], not ",
346 seed_shape.DebugString()));
347 xla::XlaOp seed = ctx->Input(1);
348 xla::XlaBuilder* builder = ctx->builder();
349
350 DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
351 xla::Shape xla_shape;
352 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
353 xla::XlaOp uniform = StatelessRngUniform(
354 device_type_string_, seed, xla_shape,
355 xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
356 xla::One(builder, xla_shape.element_type()));
357 xla::XlaOp truncated_normal = TruncatedNormal(uniform);
358 truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
359 ctx->SetOutput(0, truncated_normal);
360 }
361
362 private:
363 DataType dtype_;
364 string device_type_string_;
365
366 TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
367 };
368
369 REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
370 .CompileTimeConstantInput("shape")
371 .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
372 .TypeConstraint("Tseed", DT_INT32),
373 StatelessTruncatedNormalOp);
374
375 } // namespace
376 } // namespace tensorflow
377