/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

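// When `dtype` is DT_BFLOAT16, truncates an F32 value to BF16 precision by
// zeroing the low 16 bits of its bit pattern before converting to BF16.
// For any other dtype the input is returned unchanged.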
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
  if (dtype == DT_BFLOAT16) {
    xla::XlaBuilder* builder = input.builder();
    auto output = xla::BitcastConvertType(input, xla::U32) &
                  xla::ConstantR0<uint32>(builder, 0xFFFF0000);
    return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
                                   xla::BF16);
  } else {
    return input;
  }
}

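// Converts samples that are uniform on (-1, 1) to samples from a standard
// normal. Since the normal CDF is Phi(x) = (1 + erf(x / sqrt(2))) / 2, its
// inverse is Phi^-1(p) = sqrt(2) * erfinv(2p - 1); with x already uniform in
// (-1, 1), sqrt(2) * erfinv(x) is exactly this inverse-CDF (probit) transform.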
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
  // Convert a uniform distribution to a normal distribution by computing
  // sqrt(2) * erfinv(x).
  return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
}

// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
// values with a uniform distribution in the range [minval, maxval) for the
// given shape and two 32-bit seeds. Currently only shapes with element type
// F32, S32, and S64 are implemented.
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype,
                                      xla::XlaOp seed, xla::XlaOp minval,
                                      xla::XlaOp maxval) {
  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
  return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
}

namespace {

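// Implements the StatelessRandomUniform op: samples uniformly from [0, 1) as
// F32 (optionally truncated to BF16), deterministically given the two 32-bit
// seeds supplied as the second input. Illustrative Python-level call
// (argument values are indicative only):
//   tf.random.stateless_uniform(shape, seed=[1, 2], dtype=tf.float32)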
class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
        xla::ConstantR0<float>(builder, 1.0));
    uniform = MaybeConvertF32ToBF16(uniform, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniform")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformOp);

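// Implements StatelessRandomUniformInt: like StatelessRandomUniform, but for
// integer dtypes, with minval and maxval supplied as runtime inputs 2 and 3
// rather than fixed to [0, 1).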
class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    TensorShape minval_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp seed = ctx->Input(1);
    xla::XlaOp minval = ctx->Input(2);
    xla::XlaOp maxval = ctx->Input(3);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
    xla::XlaOp uniform =
        StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

// TODO(phawkins): generalize to non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomUniformInt")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_INT32, DT_INT64})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomUniformIntOp);

class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
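    // The lower bound is the float just above -1.0f so that the half-open
    // uniform range never includes -1 itself, where ErfInv would return -inf.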
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed,
        xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
        xla::ConstantR0<float>(builder, 1.0));
    xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
    normal = MaybeConvertF32ToBF16(normal, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

// TODO(phawkins): generalize to non-float, non-int32 seed types.
REGISTER_XLA_OP(Name("StatelessRandomNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessRandomNormalOp);

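// Implements StatelessTruncatedNormal by mapping the uniform sample through
// the TruncatedNormal helper from tf2xla/lib/random.h. TensorFlow's truncated
// normal keeps values within two standard deviations of the mean.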
class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    TensorShape seed_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(1);
    xla::XlaBuilder* builder = ctx->builder();

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
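    // Use the smallest positive normal float as the lower bound so that the
    // uniform samples lie strictly inside (0, 1) before being mapped through
    // the truncated-normal transform.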
    xla::XlaOp uniform = StatelessRandomUniformImpl(
        xla_shape, dtype_, seed,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(uniform);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16})
                    .TypeConstraint("Tseed", DT_INT32),
                StatelessTruncatedNormalOp);

}  // namespace
}  // namespace tensorflow