/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
34 
35 namespace tensorflow {
36 
37 namespace {
38 
GetBitGeneratorForDevice(absl::string_view device_type_string)39 xla::BitGeneratorTy GetBitGeneratorForDevice(
40     absl::string_view device_type_string) {
41   // The Philox algorithm may cause performance regression on other devices.
42   // Turn on the Philox algorithm for the CPU and GPU backends only.
43   if (device_type_string == DEVICE_GPU_XLA_JIT ||
44       device_type_string == DEVICE_CPU_XLA_JIT) {
45     return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
46       std::tie(state, key) = xla::ScramblePhiloxKey(key);
47       xla::XlaOp philox_state =
48           xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
49       xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
50                                                philox_state, shape);
51       return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
52                             /*state=*/xla::GetTupleElement(result, 0)};
53     };
54   }
55   return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
56     state = xla::ConcatScalars(key.builder(), {key, state});
57     xla::XlaOp result =
58         xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
59     return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
60                           /*state=*/xla::GetTupleElement(result, 0)};
61   };
62 }
63 
64 }  // namespace
65 
MaybeConvertF32ToBF16(xla::XlaOp input,DataType dtype)66 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
67   if (dtype == DT_BFLOAT16) {
68     xla::XlaBuilder* builder = input.builder();
69     xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
70                         xla::ConstantR0<uint32>(builder, 0xFFFF0000);
71     return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
72                                    xla::BF16);
73   } else {
74     return input;
75   }
76 }
77 
StatelessRngUniform(absl::string_view device_type_string,xla::XlaOp seeds,const xla::Shape & shape,xla::XlaOp minval,xla::XlaOp maxval)78 xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
79                                xla::XlaOp seeds, const xla::Shape& shape,
80                                xla::XlaOp minval, xla::XlaOp maxval) {
81   xla::XlaBuilder* builder = seeds.builder();
82 
83   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
84   xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
85   xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
86   xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
87   xla::PrimitiveType type = shape.element_type();
88   switch (type) {
89     case xla::F32:
90     case xla::F64:
91       return xla::UniformFloatingPointDistribution(
92                  key, initial_state,
93                  GetBitGeneratorForDevice(device_type_string), minval, maxval,
94                  shape)
95           .value;
96     case xla::S32:  // fall through
97     case xla::S64:
98       return UniformIntDistribution(
99                  key, initial_state,
100                  GetBitGeneratorForDevice(device_type_string), minval, maxval,
101                  shape)
102           .value;
103       break;
104     default:
105       return builder->ReportError(xla::Unimplemented(
106           "Types other than F32, S32 and S64 are not implemented by "
107           "StatelessRngUniform; got %s",
108           xla::primitive_util::LowercasePrimitiveTypeName(type)));
109   }
110 }
111 
112 namespace {
113 
StatelessRngUniformFullInt(absl::string_view device_type_string,xla::XlaOp seeds,const xla::Shape & shape)114 xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
115                                       xla::XlaOp seeds,
116                                       const xla::Shape& shape) {
117   xla::XlaBuilder* builder = seeds.builder();
118 
119   xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
120   xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
121   xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
122   xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
123   xla::PrimitiveType type = shape.element_type();
124   xla::RngOutput output =
125       GetBitGeneratorForDevice(device_type_string)(key, initial_state, shape);
126   switch (type) {
127     case xla::U32:
128     case xla::U64:
129       return output.value;
130     case xla::S32:
131     case xla::S64:
132       return BitcastConvertType(output.value, type);
133     default:
134       return builder->ReportError(xla::Unimplemented(
135           "Types other than U32, S32, U64 and S64 are not implemented by "
136           "StatelessRngUniformFullInt; got: %s",
137           xla::primitive_util::LowercasePrimitiveTypeName(type)));
138   }
139 }
140 
141 class StatelessRandomUniformOp : public XlaOpKernel {
142  public:
StatelessRandomUniformOp(OpKernelConstruction * ctx)143   explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
144       : XlaOpKernel(ctx),
145         device_type_string_(ctx->device_type().type_string()) {
146     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
147   }
148 
Compile(XlaOpKernelContext * ctx)149   void Compile(XlaOpKernelContext* ctx) override {
150     xla::XlaBuilder* builder = ctx->builder();
151 
152     TensorShape shape;
153     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
154 
155     TensorShape seed_shape = ctx->InputShape(1);
156     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
157                 errors::InvalidArgument("seed must have shape [2], not ",
158                                         seed_shape.DebugString()));
159     xla::XlaOp seed = ctx->Input(1);
160 
161     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
162     xla::Shape xla_shape;
163     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
164     xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
165 
166     xla::XlaOp uniform = StatelessRngUniform(
167         device_type_string_, seed, xla_shape,
168         xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
169         xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
170     uniform = MaybeConvertF32ToBF16(uniform, dtype_);
171     ctx->SetOutput(0, uniform);
172   }
173 
174  private:
175   DataType dtype_;
176   string device_type_string_;
177 
178   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
179 };
180 
181 // TODO(phawkins): generalize to non-float, non-int32 seed types.
182 REGISTER_XLA_OP(Name("StatelessRandomUniform")
183                     .CompileTimeConstantInput("shape")
184                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
185                     .TypeConstraint("Tseed", DT_INT32),
186                 StatelessRandomUniformOp);
187 
188 class StatelessRandomUniformIntOp : public XlaOpKernel {
189  public:
StatelessRandomUniformIntOp(OpKernelConstruction * ctx)190   explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
191       : XlaOpKernel(ctx),
192         device_type_string_(ctx->device_type().type_string()) {
193     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
194   }
195 
Compile(XlaOpKernelContext * ctx)196   void Compile(XlaOpKernelContext* ctx) override {
197     TensorShape shape;
198     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
199 
200     TensorShape seed_shape = ctx->InputShape(1);
201     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
202                 errors::InvalidArgument("seed must have shape [2], not ",
203                                         seed_shape.DebugString()));
204     TensorShape minval_shape = ctx->InputShape(2);
205     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
206                 errors::InvalidArgument("minval must be scalar, got shape ",
207                                         minval_shape.DebugString()));
208     TensorShape maxval_shape = ctx->InputShape(3);
209     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
210                 errors::InvalidArgument("minval must be scalar, got shape ",
211                                         maxval_shape.DebugString()));
212 
213     xla::XlaOp seed = ctx->Input(1);
214     xla::XlaOp minval = ctx->Input(2);
215     xla::XlaOp maxval = ctx->Input(3);
216 
217     xla::Shape xla_shape;
218     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
219     xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed,
220                                              xla_shape, minval, maxval);
221 
222     ctx->SetOutput(0, uniform);
223   }
224 
225  private:
226   DataType dtype_;
227   string device_type_string_;
228 
229   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
230 };
231 
232 // TODO(phawkins): generalize to non-int32 seed types.
233 REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
234                     .CompileTimeConstantInput("shape")
235                     .TypeConstraint("dtype", {DT_INT32, DT_INT64})
236                     .TypeConstraint("Tseed", DT_INT32),
237                 StatelessRandomUniformIntOp);
238 
239 class StatelessRandomUniformFullIntOp : public XlaOpKernel {
240  public:
StatelessRandomUniformFullIntOp(OpKernelConstruction * ctx)241   explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
242       : XlaOpKernel(ctx),
243         device_type_string_(ctx->device_type().type_string()) {
244     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
245   }
246 
Compile(XlaOpKernelContext * ctx)247   void Compile(XlaOpKernelContext* ctx) override {
248     TensorShape shape;
249     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
250 
251     TensorShape seed_shape = ctx->InputShape(1);
252     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
253                 errors::InvalidArgument("seed must have shape [2], not ",
254                                         seed_shape.DebugString()));
255 
256     xla::XlaOp seed = ctx->Input(1);
257 
258     xla::Shape xla_shape;
259     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
260     xla::XlaOp uniform =
261         StatelessRngUniformFullInt(device_type_string_, seed, xla_shape);
262 
263     ctx->SetOutput(0, uniform);
264   }
265 
266  private:
267   DataType dtype_;
268   string device_type_string_;
269 
270   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
271 };
272 
273 // TODO(phawkins): generalize to non-int32 seed types.
274 REGISTER_XLA_OP(Name("StatelessRandomUniformFullInt")
275                     .CompileTimeConstantInput("shape")
276                     .TypeConstraint("dtype", {DT_INT32, DT_INT64})
277                     .TypeConstraint("Tseed", DT_INT32),
278                 StatelessRandomUniformFullIntOp);
279 
280 class StatelessRandomNormalOp : public XlaOpKernel {
281  public:
StatelessRandomNormalOp(OpKernelConstruction * ctx)282   explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
283       : XlaOpKernel(ctx),
284         device_type_string_(ctx->device_type().type_string()) {
285     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
286   }
287 
Compile(XlaOpKernelContext * ctx)288   void Compile(XlaOpKernelContext* ctx) override {
289     TensorShape shape;
290     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
291 
292     TensorShape seed_shape = ctx->InputShape(1);
293     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
294                 errors::InvalidArgument("seed must have shape [2], not ",
295                                         seed_shape.DebugString()));
296     xla::XlaOp seed = ctx->Input(1);
297     xla::Shape xla_shape;
298 
299     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
300     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
301 
302     xla::XlaBuilder* builder = seed.builder();
303     xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
304     xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
305     xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
306 
307     xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
308     xla::XlaOp normal =
309         xla::NormalFloatingPointDistribution(
310             key, initial_state, GetBitGeneratorForDevice(device_type_string_),
311             xla_shape)
312             .value;
313     normal = MaybeConvertF32ToBF16(normal, dtype_);
314     ctx->SetOutput(0, normal);
315   }
316 
317  private:
318   DataType dtype_;
319   string device_type_string_;
320 
321   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
322 };
323 
324 // TODO(phawkins): generalize to non-float, non-int32 seed types.
325 REGISTER_XLA_OP(Name("StatelessRandomNormal")
326                     .CompileTimeConstantInput("shape")
327                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
328                     .TypeConstraint("Tseed", DT_INT32),
329                 StatelessRandomNormalOp);
330 
331 class StatelessTruncatedNormalOp : public XlaOpKernel {
332  public:
StatelessTruncatedNormalOp(OpKernelConstruction * ctx)333   explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
334       : XlaOpKernel(ctx),
335         device_type_string_(ctx->device_type().type_string()) {
336     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
337   }
338 
Compile(XlaOpKernelContext * ctx)339   void Compile(XlaOpKernelContext* ctx) override {
340     TensorShape shape;
341     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
342 
343     TensorShape seed_shape = ctx->InputShape(1);
344     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
345                 errors::InvalidArgument("seed must have shape [2], not ",
346                                         seed_shape.DebugString()));
347     xla::XlaOp seed = ctx->Input(1);
348     xla::XlaBuilder* builder = ctx->builder();
349 
350     DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
351     xla::Shape xla_shape;
352     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
353     xla::XlaOp uniform = StatelessRngUniform(
354         device_type_string_, seed, xla_shape,
355         xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
356         xla::One(builder, xla_shape.element_type()));
357     xla::XlaOp truncated_normal = TruncatedNormal(uniform);
358     truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
359     ctx->SetOutput(0, truncated_normal);
360   }
361 
362  private:
363   DataType dtype_;
364   string device_type_string_;
365 
366   TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
367 };
368 
369 REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
370                     .CompileTimeConstantInput("shape")
371                     .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16})
372                     .TypeConstraint("Tseed", DT_INT32),
373                 StatelessTruncatedNormalOp);
374 
375 }  // namespace
376 }  // namespace tensorflow
377