/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
  if (dtype == DT_BFLOAT16) {
    xla::XlaBuilder* builder = input.builder();
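    // Zero out the low 16 bits of each F32 bit pattern, keeping the sign,
    // exponent, and top 7 mantissa bits, so the value is exactly
    // representable in BF16 before the element-type conversion below.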
    auto output = xla::BitcastConvertType(input, xla::U32) &
                  xla::ConstantR0<uint32>(builder, 0xFFFF0000);
    return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
                                   xla::BF16);
  } else {
    return input;
  }
}

xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
  // Convert a uniform distribution to a normal distribution by computing
  // sqrt(2) * erfinv(x).
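  // For x drawn uniformly from the open interval (-1, 1), sqrt(2) * erfinv(x)
  // is the standard-normal inverse CDF evaluated at (x + 1) / 2, so the
  // result follows a standard normal distribution.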
  return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
}

// A wrapper around xla::StatelessRngUniform. Returns an op that produces
// random values uniformly distributed in the range [minval, maxval) for the
// given shape and a pair of 32-bit seeds. Currently only F32, S32, and S64
// element types are implemented.
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype,
                                      xla::XlaOp seed, xla::XlaOp minval,
                                      xla::XlaOp maxval) {
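  // The seed is a vector of two 32-bit values; split it into the two scalar
  // seeds expected by the underlying generator.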
  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
  return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
}

namespace {

class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
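    // The sample is always generated as F32; MaybeConvertF32ToBF16 narrows
    // the result afterwards when the requested dtype is BF16.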
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
        xla::ConstantR0<float>(builder, 1.0));
    uniform = MaybeConvertF32ToBF16(uniform, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniform")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformOp);

class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    TensorShape minval_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp seed = ctx->Input(1);
    xla::XlaOp minval = ctx->Input(2);
    xla::XlaOp maxval = ctx->Input(3);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
    xla::XlaOp uniform =
        StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

// TODO(phawkins): generalize to non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_INT32, DT_INT64})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformIntOp);

class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
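    // Sample uniformly from the open interval (-1, 1): the lower bound is the
    // smallest float above -1.0f, so ErfInv never sees exactly -1, which
    // would map to -infinity.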
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed,
        xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
        xla::ConstantR0<float>(builder, 1.0));
    xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
    normal = MaybeConvertF32ToBF16(normal, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomNormalOp);

class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
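    // Sample the underlying uniform values from (min positive normal, 1)
    // rather than [0, 1), so TruncatedNormal never receives an input of
    // exactly zero.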
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(uniform);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessTruncatedNormalOp);

}  // namespace
}  // namespace tensorflow