• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 
18 #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
19 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
20 #include "tensorflow/compiler/tf2xla/lib/random.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/lib/math.h"
28 #include "tensorflow/compiler/xla/client/lib/prng.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/lib/math/math_util.h"
35 
36 namespace tensorflow {
37 
38 namespace {
39 
GetBitGeneratorForDevice(absl::string_view device_type_string)40 xla::BitGeneratorTy GetBitGeneratorForDevice(
41     absl::string_view device_type_string) {
42   // The Philox algorithm may cause performance regression on other devices.
43   // Turn on the Philox algorithm for the CPU and GPU backends only.
44   if (device_type_string == DEVICE_GPU_XLA_JIT ||
45       device_type_string == DEVICE_CPU_XLA_JIT) {
46     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
47       std::tie(state, key) = xla::ScramblePhiloxKey(key);
48       xla::XlaOp philox_state =
49           xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
50       xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
51                                                philox_state, shape);
52       return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
53                             /*state=*/xla::GetTupleElement(result, 0)};
54     };
55   }
56   return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
57     state = xla::ConcatScalars(key.builder(), {key, state});
58     xla::XlaOp result =
59         xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
60     return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
61                           /*state=*/xla::GetTupleElement(result, 0)};
62   };
63 }
64 
65 }  // namespace
66 
MaybeConvertF32ToBF16(xla::XlaOp input,DataType dtype)67 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
68   if (dtype == DT_BFLOAT16) {
69     xla::XlaBuilder* builder = input.builder();
70     xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
71                         xla::ConstantR0<uint32>(builder, 0xFFFF0000);
72     return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
73                                    xla::BF16);
74   } else {
75     return input;
76   }
77 }
78 
StatelessRngUniform(absl::string_view device_type_string,xla::XlaOp seeds,const xla::Shape & shape,xla::XlaOp minval,xla::XlaOp maxval)79 xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
80                                xla::XlaOp seeds, const xla::Shape& shape,
81                                xla::XlaOp minval, xla::XlaOp maxval) {
82   xla::XlaBuilder* builder = seeds.builder();
83 
84   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
85   xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
86   xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
87   xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
88   xla::PrimitiveType type = shape.element_type();
89   switch (type) {
90     case xla::F32:
91     case xla::F64:
92       return xla::UniformFloatingPointDistribution(
93                  key, initial_state,
94                  GetBitGeneratorForDevice(device_type_string), minval, maxval,
95                  shape)
96           .value;
97     case xla::S32:  // fall through
98     case xla::S64:
99       return UniformIntDistribution(
100                  key, initial_state,
101                  GetBitGeneratorForDevice(device_type_string), minval, maxval,
102                  shape)
103           .value;
104       break;
105     default:
106       return builder->ReportError(xla::Unimplemented(
107           "Types other than F32, S32 and S64 are not implemented by "
108           "StatelessRngUniform; got %s",
109           xla::primitive_util::LowercasePrimitiveTypeName(type)));
110   }
111 }
112 
113 namespace {
114 
StatelessRngUniformFullInt(absl::string_view device_type_string,xla::XlaOp seeds,const xla::Shape & shape)115 xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
116                                       xla::XlaOp seeds,
117                                       const xla::Shape& shape) {
118   xla::XlaBuilder* builder = seeds.builder();
119 
120   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
121   xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
122   xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
123   xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
124   xla::PrimitiveType type = shape.element_type();
125   xla::RngOutput output =
126       GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape);
127   switch (type) {
128     case xla::U32:
129     case xla::U64:
130       return output.value;
131     case xla::S32:
132     case xla::S64:
133       return BitcastConvertType(output.value, type);
134     default:
135       return builder->ReportError(xla::Unimplemented(
136           "Types other than U32, S32, U64 and S64 are not implemented by "
137           "StatelessRngUniformFullInt; got: %s",
138           xla::primitive_util::LowercasePrimitiveTypeName(type)));
139   }
140 }
141 
142 class StatelessRandomUniformOp : public XlaOpKernel {
143  public:
StatelessRandomUniformOp(OpKernelConstruction * ctx)144   explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
145       : XlaOpKernel(ctx),
146         device_type_string_(ctx->device_type().type_string()) {
147     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
148   }
149 
Compile(XlaOpKernelContext * ctx)150   void Compile(XlaOpKernelContext* ctx) override {
151     xla::XlaBuilder* builder = ctx->builder();
152 
153     TensorShape shape;
154     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
155 
156     TensorShape seed_shape = ctx->InputShape(1);
157     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
158                 errors::InvalidArgument("seed must have shape [2], not ",
159                                         seed_shape.DebugString()));
160     xla::XlaOp seed = ctx->Input(1);
161 
162     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
163     xla::Shape xla_shape;
164     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
165     xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
166 
167     xla::XlaOp uniform = StatelessRngUniform(
168         device_type_string_, seed, xla_shape,
169         xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
170         xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
171     uniform = MaybeConvertF32ToBF16(uniform, dtype_);
172     ctx->SetOutput(0, uniform);
173   }
174 
175  private:
176   DataType dtype_;
177   string device_type_string_;
178 
179   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
180 };
181 
182 // TODO(phawkins): generalize to non-float, non-int32 seed types.
183 REGISTER_XLA_OP(Name("StatelessRandomUniform")
184                     .CompileTimeConstantInput("shape")
185                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
186                     .TypeConstraint("Tseed", DT_INT32),
187                 StatelessRandomUniformOp);
188 
189 class StatelessRandomUniformIntOp : public XlaOpKernel {
190  public:
StatelessRandomUniformIntOp(OpKernelConstruction * ctx)191   explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
192       : XlaOpKernel(ctx),
193         device_type_string_(ctx->device_type().type_string()) {
194     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
195   }
196 
Compile(XlaOpKernelContext * ctx)197   void Compile(XlaOpKernelContext* ctx) override {
198     TensorShape shape;
199     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
200 
201     TensorShape seed_shape = ctx->InputShape(1);
202     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
203                 errors::InvalidArgument("seed must have shape [2], not ",
204                                         seed_shape.DebugString()));
205     TensorShape minval_shape = ctx->InputShape(2);
206     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
207                 errors::InvalidArgument("minval must be scalar, got shape ",
208                                         minval_shape.DebugString()));
209     TensorShape maxval_shape = ctx->InputShape(3);
210     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
211                 errors::InvalidArgument("minval must be scalar, got shape ",
212                                         maxval_shape.DebugString()));
213 
214     xla::XlaOp seed = ctx->Input(1);
215     xla::XlaOp minval = ctx->Input(2);
216     xla::XlaOp maxval = ctx->Input(3);
217 
218     xla::Shape xla_shape;
219     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
220     xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed,
221                                              xla_shape, minval, maxval);
222 
223     ctx->SetOutput(0, uniform);
224   }
225 
226  private:
227   DataType dtype_;
228   string device_type_string_;
229 
230   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
231 };
232 
233 // TODO(phawkins): generalize to non-int32 seed types.
234 REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
235                     .CompileTimeConstantInput("shape")
236                     .TypeConstraint("dtype", {DT_INT32, DT_INT64})
237                     .TypeConstraint("Tseed", DT_INT32),
238                 StatelessRandomUniformIntOp);
239 
240 class StatelessRandomUniformFullIntOp : public XlaOpKernel {
241  public:
StatelessRandomUniformFullIntOp(OpKernelConstruction * ctx)242   explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
243       : XlaOpKernel(ctx),
244         device_type_string_(ctx->device_type().type_string()) {
245     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
246   }
247 
Compile(XlaOpKernelContext * ctx)248   void Compile(XlaOpKernelContext* ctx) override {
249     TensorShape shape;
250     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
251 
252     TensorShape seed_shape = ctx->InputShape(1);
253     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
254                 errors::InvalidArgument("seed must have shape [2], not ",
255                                         seed_shape.DebugString()));
256 
257     xla::XlaOp seed = ctx->Input(1);
258 
259     xla::Shape xla_shape;
260     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
261     xla::XlaOp uniform =
262         StatelessRngUniformFullInt(device_type_string_, seed, xla_shape);
263 
264     ctx->SetOutput(0, uniform);
265   }
266 
267  private:
268   DataType dtype_;
269   string device_type_string_;
270 
271   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
272 };
273 
274 // TODO(phawkins): generalize to non-int32 seed types.
275 REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt")
276                     .CompileTimeConstantInput("shape")
277                     .TypeConstraint("dtype", {DT_INT32, DT_INT64})
278                     .TypeConstraint("Tseed", DT_INT32),
279                 StatelessRandomUniformFullIntOp);
280 
281 class StatelessRandomNormalOp : public XlaOpKernel {
282  public:
StatelessRandomNormalOp(OpKernelConstruction * ctx)283   explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
284       : XlaOpKernel(ctx),
285         device_type_string_(ctx->device_type().type_string()) {
286     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
287   }
288 
Compile(XlaOpKernelContext * ctx)289   void Compile(XlaOpKernelContext* ctx) override {
290     TensorShape shape;
291     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
292 
293     TensorShape seed_shape = ctx->InputShape(1);
294     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
295                 errors::InvalidArgument("seed must have shape [2], not ",
296                                         seed_shape.DebugString()));
297     xla::XlaOp seed = ctx->Input(1);
298     xla::Shape xla_shape;
299 
300     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
301     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
302 
303     xla::XlaBuilder* builder = seed.builder();
304     xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
305     xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
306     xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
307 
308     xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
309     xla::XlaOp normal =
310         xla::NormalFloatingPointDistribution(
311             key, initial_state, GetBitGeneratorForDevice(device_type_string_),
312             xla_shape)
313             .value;
314     normal = MaybeConvertF32ToBF16(normal, dtype_);
315     ctx->SetOutput(0, normal);
316   }
317 
318  private:
319   DataType dtype_;
320   string device_type_string_;
321 
322   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
323 };
324 
325 // TODO(phawkins): generalize to non-float, non-int32 seed types.
326 REGISTER_XLA_OP(Name("StatelessRandomNormal")
327                     .CompileTimeConstantInput("shape")
328                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
329                     .TypeConstraint("Tseed", DT_INT32),
330                 StatelessRandomNormalOp);
331 
332 class StatelessTruncatedNormalOp : public XlaOpKernel {
333  public:
StatelessTruncatedNormalOp(OpKernelConstruction * ctx)334   explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
335       : XlaOpKernel(ctx),
336         device_type_string_(ctx->device_type().type_string()) {
337     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
338   }
339 
Compile(XlaOpKernelContext * ctx)340   void Compile(XlaOpKernelContext* ctx) override {
341     TensorShape shape;
342     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
343 
344     TensorShape seed_shape = ctx->InputShape(1);
345     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
346                 errors::InvalidArgument("seed must have shape [2], not ",
347                                         seed_shape.DebugString()));
348     xla::XlaOp seed = ctx->Input(1);
349     xla::XlaBuilder* builder = ctx->builder();
350 
351     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
352     xla::Shape xla_shape;
353     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
354     xla::XlaOp uniform = StatelessRngUniform(
355         device_type_string_, seed, xla_shape,
356         xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
357         xla::One(builder, xla_shape.element_type()));
358     xla::XlaOp truncated_normal = TruncatedNormal(uniform);
359     truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
360     ctx->SetOutput(0, truncated_normal);
361   }
362 
363  private:
364   DataType dtype_;
365   string device_type_string_;
366 
367   TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
368 };
369 
370 REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
371                     .CompileTimeConstantInput("shape")
372                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
373                     .TypeConstraint("Tseed", DT_INT32),
374                 StatelessTruncatedNormalOp);
375 
376 class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel {
377  public:
StatelessParameterizedTruncatedNormalOp(OpKernelConstruction * ctx)378   explicit StatelessParameterizedTruncatedNormalOp(OpKernelConstruction* ctx)
379       : XlaOpKernel(ctx),
380         device_type_string_(ctx->device_type().type_string()) {
381     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
382   }
383 
Compile(XlaOpKernelContext * ctx)384   void Compile(XlaOpKernelContext* ctx) override {
385     TensorShape shape;
386     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
387 
388     TensorShape seed_shape = ctx->InputShape(1);
389     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
390                 errors::InvalidArgument("seed must have shape [2], not ",
391                                         seed_shape.DebugString()));
392     xla::XlaOp seed = ctx->Input(1);
393     xla::XlaBuilder* builder = ctx->builder();
394 
395     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
396 
397     xla::Shape xla_shape;
398     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
399 
400     auto bcasted_means = BroadcastTo(ctx->Input(2), shape.dim_sizes());
401     OP_REQUIRES_OK(ctx, bcasted_means.status());
402     auto means = bcasted_means.ValueOrDie();
403 
404     auto bcasted_stddevs = BroadcastTo(ctx->Input(3), shape.dim_sizes());
405     OP_REQUIRES_OK(ctx, bcasted_stddevs.status());
406     auto stddevs = bcasted_stddevs.ValueOrDie();
407 
408     auto bcasted_minvals = BroadcastTo(ctx->Input(4), shape.dim_sizes());
409     OP_REQUIRES_OK(ctx, bcasted_minvals.status());
410     auto minvals = bcasted_minvals.ValueOrDie();
411 
412     auto bcasted_maxvals = BroadcastTo(ctx->Input(5), shape.dim_sizes());
413     OP_REQUIRES_OK(ctx, bcasted_maxvals.status());
414     auto maxvals = bcasted_maxvals.ValueOrDie();
415 
416     xla::XlaOp uniform = StatelessRngUniform(
417         device_type_string_, seed, xla_shape,
418         xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
419         xla::One(builder, xla_shape.element_type()));
420     xla::XlaOp truncated_normal =
421         ParameterizedTruncatedNormal(uniform, means, stddevs, minvals, maxvals);
422     truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
423     ctx->SetOutput(0, truncated_normal);
424   }
425 
426  private:
427   DataType dtype_;
428   string device_type_string_;
429 
430   TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormalOp);
431 };
432 
433 REGISTER_XLA_OP(Name("StatelessParameterizedTruncatedNormal")
434                     .CompileTimeConstantInput("shape")
435                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
436                     .TypeConstraint("Tseed", DT_INT32),
437                 StatelessParameterizedTruncatedNormalOp);
438 
439 }  // namespace
440 }  // namespace tensorflow
441