/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

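// Returns a callback matching xla::BitGeneratorTy that produces raw random
// bits of the requested shape from a (key, state) pair: Philox on CPU/GPU,
// the backend's default generator elsewhere (see the comment below). A
// minimal usage sketch, assuming `key` and `initial_state` are U64 scalar
// XlaOps already built on the same XlaBuilder:
//
//   xla::BitGeneratorTy gen = GetBitGeneratorForDevice(DEVICE_CPU_XLA_JIT);
//   xla::RngOutput out = gen(key, initial_state, shape);
//   // out.value holds the generated bits; out.state is the updated state.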
xla::BitGeneratorTy GetBitGeneratorForDevice(
    absl::string_view device_type_string) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      std::tie(state, key) = xla::ScramblePhiloxKey(key);
      xla::XlaOp philox_state =
          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
      xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
                                               philox_state, shape);
      return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                            /*state=*/xla::GetTupleElement(result, 0)};
    };
  }
  return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
    state = xla::ConcatScalars(key.builder(), {key, state});
    xla::XlaOp result =
        xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
    return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                          /*state=*/xla::GetTupleElement(result, 0)};
  };
}

}  // namespace

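// When the requested dtype is bfloat16, drops the low 16 mantissa bits of an
// F32 value via a bitcast-and-mask round trip, then converts to BF16. A
// worked example: f32 bits 0x3F8F1234 (~1.1172) masked with 0xFFFF0000 give
// 0x3F8F0000, whose high half 0x3F8F is the resulting bfloat16 payload
// (1.1171875). For any other dtype the input passes through unchanged.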
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
  if (dtype == DT_BFLOAT16) {
    xla::XlaBuilder* builder = input.builder();
    xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
                        xla::ConstantR0<uint32>(builder, 0xFFFF0000);
    return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
                                   xla::BF16);
  } else {
    return input;
  }
}

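// Emits XLA ops that draw values uniformly from [minval, maxval) in the
// element type of `shape`, folding the two int32 seeds into a single U64
// generator key. A minimal sketch of a call site, assuming `seeds` is an
// S32[2] operand and `builder` is its XlaBuilder:
//
//   xla::Shape f32_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
//   xla::XlaOp u = StatelessRngUniform(
//       DEVICE_CPU_XLA_JIT, seeds, f32_shape,
//       xla::ConstantR0<float>(builder, 0.0f),
//       xla::ConstantR0<float>(builder, 1.0f));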
xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
                               xla::XlaOp seeds, const xla::Shape& shape,
                               xla::XlaOp minval, xla::XlaOp maxval) {
  xla::XlaBuilder* builder = seeds.builder();

  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
  xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
  xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(
                 key, initial_state,
                 GetBitGeneratorForDevice(device_type_string), minval, maxval,
                 shape)
          .value;
    case xla::S32:  // fall through
    case xla::S64:
      return xla::UniformIntDistribution(
                 key, initial_state,
                 GetBitGeneratorForDevice(device_type_string), minval, maxval,
                 shape)
          .value;
    default:
      return builder->ReportError(xla::Unimplemented(
          "Types other than F32, F64, S32 and S64 are not implemented by "
          "StatelessRngUniform; got %s",
          xla::primitive_util::LowercasePrimitiveTypeName(type)));
  }
}

namespace {

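// Draws integers covering the full range of the requested integer type by
// returning the generator's raw bits directly; signed outputs are produced
// by bitcasting the unsigned bits, so no rejection or modulo step is needed.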
xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
                                      xla::XlaOp seeds,
                                      const xla::Shape& shape) {
  xla::XlaBuilder* builder = seeds.builder();

  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
  xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
  xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output =
      GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output.value;
    case xla::S32:
    case xla::S64:
      return BitcastConvertType(output.value, type);
    default:
      return builder->ReportError(xla::Unimplemented(
          "Types other than U32, S32, U64 and S64 are not implemented by "
          "StatelessRngUniformFullInt; got: %s",
          xla::primitive_util::LowercasePrimitiveTypeName(type)));
  }
}

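// Kernel for the StatelessRandomUniform op: uniform floats in [0, 1).
// bfloat16 outputs are sampled in F32 and then narrowed, so all float dtypes
// share one RNG code path. This is the kernel that typically backs
// tf.random.stateless_uniform for float dtypes under XLA compilation.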
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    xla::XlaOp uniform = StatelessRngUniform(
        device_type_string_, seed, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    uniform = MaybeConvertF32ToBF16(uniform, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniform")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformOp);

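// Kernel for the StatelessRandomUniformInt op: uniform integers in
// [minval, maxval), where minval and maxval arrive as scalar runtime inputs
// (inputs 2 and 3) rather than as attributes.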
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    TensorShape minval_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp seed = ctx->Input(1);
    xla::XlaOp minval = ctx->Input(2);
    xla::XlaOp maxval = ctx->Input(3);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
    xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed,
                                             xla_shape, minval, maxval);

    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

// TODO(phawkins): generalize to non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_INT32, DT_INT64})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformIntOp);

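// Kernel for the StatelessRandomUniformFullInt op: integers uniform over the
// entire range of the output type (e.g. all 2^32 values for int32), built on
// StatelessRngUniformFullInt above.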
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));

    xla::XlaOp seed = ctx->Input(1);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
    xla::XlaOp uniform =
        StatelessRngUniformFullInt(device_type_string_, seed, xla_shape);

    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

// TODO(phawkins): generalize to non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_INT32, DT_INT64})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformFullIntOp);

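// Kernel for the StatelessRandomNormal op: standard-normal samples (mean 0,
// stddev 1). The int32 seed pair is folded into a single U64 key, and the
// samples come from xla::NormalFloatingPointDistribution driven by the
// device's bit generator.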
class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::Shape xla_shape;

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);

    xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
    xla::XlaOp normal =
        xla::NormalFloatingPointDistribution(
            key, initial_state, GetBitGeneratorForDevice(device_type_string_),
            xla_shape)
            .value;
    normal = MaybeConvertF32ToBF16(normal, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomNormalOp);

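// Kernel for the StatelessTruncatedNormal op: normal samples truncated to
// two standard deviations. The uniform samples are drawn from
// [min-positive-normal, 1) rather than [0, 1) so the inverse-CDF step inside
// TruncatedNormal never sees an exact 0, which would map to -infinity.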
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::XlaOp uniform = StatelessRngUniform(
        device_type_string_, seed, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(uniform);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessTruncatedNormalOp);

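// Kernel for the StatelessParameterizedTruncatedNormal op: like
// StatelessTruncatedNormal, but the mean, stddev, and truncation bounds are
// runtime inputs (inputs 2-5), each broadcast to the output shape before
// sampling. For example, a scalar mean input of 3.0 with output shape [2, 2]
// is broadcast to a [2, 2] tensor of 3.0s.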
class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessParameterizedTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    auto bcasted_means = BroadcastTo(ctx->Input(2), shape.dim_sizes());
    OP_REQUIRES_OK(ctx, bcasted_means.status());
    auto means = bcasted_means.ValueOrDie();

    auto bcasted_stddevs = BroadcastTo(ctx->Input(3), shape.dim_sizes());
    OP_REQUIRES_OK(ctx, bcasted_stddevs.status());
    auto stddevs = bcasted_stddevs.ValueOrDie();

    auto bcasted_minvals = BroadcastTo(ctx->Input(4), shape.dim_sizes());
    OP_REQUIRES_OK(ctx, bcasted_minvals.status());
    auto minvals = bcasted_minvals.ValueOrDie();

    auto bcasted_maxvals = BroadcastTo(ctx->Input(5), shape.dim_sizes());
    OP_REQUIRES_OK(ctx, bcasted_maxvals.status());
    auto maxvals = bcasted_maxvals.ValueOrDie();

    xla::XlaOp uniform = StatelessRngUniform(
        device_type_string_, seed, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal =
        ParameterizedTruncatedNormal(uniform, means, stddevs, minvals, maxvals);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessParameterizedTruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessParameterizedTruncatedNormalOp);

}  // namespace
}  // namespace tensorflow