1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Native XLA implementations of simple binary Ops
17
18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29
30 namespace tensorflow {
31 namespace {
32
// A subclass of a XlaBinaryOp must build the computation that
// describes the (tensor,tensor)->tensor function to apply to each element of
// the input.
//
// XLA_MAKE_BINARY(NAME, HLO) stamps out one such subclass, NAME##Op, whose
// Computation() method evaluates the expression HLO and returns its result,
// and registers it under the TensorFlow op name NAME.  Inside HLO the
// following names are in scope: ctx, b (the XlaBuilder), lhs, rhs, lhs_shape,
// rhs_shape, broadcast_helper, and extend_dimensions.  The (void) casts
// suppress unused-variable warnings for HLO expressions that do not use
// every binding.
#define XLA_MAKE_BINARY(NAME, HLO)                                       \
  class NAME##Op : public XlaBinaryOp {                                  \
   public:                                                               \
    explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}   \
    xla::XlaOp Computation(                                              \
        XlaOpKernelContext* ctx, const xla::XlaOp& lhs,                  \
        const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, \
        const absl::Span<const int64>& rhs_shape,                        \
        const BCast& broadcast_helper,                                   \
        const std::vector<int64>& extend_dimensions) override {          \
      xla::XlaBuilder* b = ctx->builder();                               \
      (void)b;                                                           \
      (void)lhs_shape;                                                   \
      (void)rhs_shape;                                                   \
      (void)extend_dimensions;                                           \
      return HLO;                                                        \
    }                                                                    \
  };                                                                     \
  REGISTER_XLA_OP(Name(#NAME), NAME##Op)
55
// Elementwise arithmetic.  extend_dimensions is forwarded as the XLA
// broadcast_dimensions argument, aligning a lower-rank operand with the
// higher-rank one.
XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
// Builds a complex tensor from real (lhs) and imaginary (rhs) parts.
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
63
64 // Implementation of DivNoNan. Pseudo-code:
65 // if (y == 0) {
66 // return 0
67 // } else {
68 // return x / y;
69 // }
DivNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)70 static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
71 xla::XlaOp y, const BCast& broadcast_helper) {
72 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
73 auto zero = XlaHelpers::Zero(b, dtype);
74 auto y_equals_0 = xla::Eq(y, zero);
75 auto zeros = xla::ZerosLike(x);
76 auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
77 return result;
78 }
79 XLA_MAKE_BINARY(DivNoNan,
80 DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
81
82 // Implementation of MulNoNan. Pseudo-code:
83 // if (y == 0) {
84 // return 0
85 // } else {
86 // return x * y;
87 // }
MulNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)88 static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
89 xla::XlaOp y, const BCast& broadcast_helper) {
90 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
91 auto zero = XlaHelpers::Zero(b, dtype);
92 auto y_equals_0 = xla::Eq(y, zero);
93 auto zeros = xla::ZerosLike(x);
94 auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
95 return result;
96 }
97 XLA_MAKE_BINARY(MulNoNan,
98 MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
99
100 // Implementation of FloorDiv.
101 //
102 // For floating-point values, simply returns floor(x / y). For integers, does:
103 //
104 // if ((x < 0) != (y < 0)) {
105 // T abs_x = std::abs(x);
106 // T abs_y = std::abs(y);
107 // return -(abs_x + abs_y - 1) / abs_y;
108 // } else {
109 // return x / y;
110 // }
FloorDivImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)111 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
112 xla::XlaOp y, const BCast& broadcast_helper) {
113 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
114 if (DataTypeIsFloating(dtype)) {
115 return xla::Floor(xla::Div(x, y));
116 }
117 if (DataTypeIsUnsigned(dtype)) {
118 return xla::Div(x, y);
119 }
120 auto zero = XlaHelpers::Zero(b, dtype);
121 auto one = XlaHelpers::One(b, dtype);
122 auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
123 auto abs_x = xla::Abs(x);
124 auto abs_y = xla::Abs(y);
125 auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
126 return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
127 }
128 XLA_MAKE_BINARY(FloorDiv,
129 FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
130
XlogyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)131 xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
132 const BCast& broadcast_helper) {
133 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
134 auto zero = xla::ZerosLike(x);
135 auto is_zero = xla::Eq(x, zero);
136 return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
137 }
138 XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
139
XdivyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)140 xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y,
141 const BCast& broadcast_helper) {
142 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
143 auto zero = xla::ZerosLike(x);
144 auto is_zero = xla::Eq(x, zero);
145 return xla::Select(is_zero, zero, xla::Div(x, y));
146 }
147 XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));
148
149 // Implementation of FloorMod. Pseudo-code:
150 // T trunc_mod = std::fmod(x, y);
151 // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
FloorModImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)152 static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
153 xla::XlaOp y, const BCast& broadcast_helper) {
154 std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
155 auto zero = XlaHelpers::Zero(b, dtype);
156 auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
157 auto trunc_mod = xla::Rem(x, y);
158 return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
159 }
160 XLA_MAKE_BINARY(FloorMod,
161 FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
162
// Integer bitwise ops.
XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));

XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
// RightShift is zero-filling (logical) for unsigned inputs and
// sign-extending (arithmetic) for signed inputs.
XLA_MAKE_BINARY(RightShift,
                (DataTypeIsUnsigned(ctx->input_type(0))
                     ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
                     : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));

XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
// Mod/TruncateMod use truncated (C-style) remainder semantics.
XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
// ReciprocalGrad: -rhs * lhs^2.
XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
// RsqrtGrad: lhs^3 * (rhs / -2).
XLA_MAKE_BINARY(
    RsqrtGrad,
    xla::Mul((lhs * lhs) * lhs,
             xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
             extend_dimensions));
// SqrtGrad: (rhs * 0.5) / lhs.
XLA_MAKE_BINARY(
    SqrtGrad,
    xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
             lhs, extend_dimensions));

XLA_MAKE_BINARY(SquaredDifference,
                xla::Square(xla::Sub(lhs, rhs, extend_dimensions)));

XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));
195
// Comparison ops
XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));
203
// Non-linear ops
// SigmoidGrad: rhs * lhs * (1 - lhs).
XLA_MAKE_BINARY(SigmoidGrad,
                xla::Mul(xla::Mul(rhs, lhs),
                         xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));

// SoftplusGrad: lhs / (exp(-rhs) + 1).
XLA_MAKE_BINARY(SoftplusGrad,
                xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)),
                                       XlaHelpers::One(b, input_type(1)))));

// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
                xla::Div(lhs,
                         xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)),
                                              xla::Abs(rhs)))));

// TanhGrad: rhs * (1 - lhs * lhs).
XLA_MAKE_BINARY(TanhGrad,
                xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
                                       xla::Mul(lhs, lhs))));

XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));

// Next representable value after lhs in the direction of rhs.
XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs));
226
227 #undef XLA_MAKE_BINARY
228
229 class ApproximateEqualOp : public XlaOpKernel {
230 public:
ApproximateEqualOp(OpKernelConstruction * ctx)231 explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
232 OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
233 }
234
235 // Computes the max of the scalar input x and 0.
Compile(XlaOpKernelContext * ctx)236 void Compile(XlaOpKernelContext* ctx) override {
237 xla::XlaBuilder* b = ctx->builder();
238 auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
239 auto abs_shape = b->GetShape(abs);
240 OP_REQUIRES_OK(ctx, abs_shape.status());
241 auto abs_type = abs_shape.ValueOrDie().element_type();
242 auto result =
243 xla::Lt(abs, xla::ConvertElementType(
244 xla::ConstantR0<float>(b, tolerance_), abs_type));
245 ctx->SetOutput(0, result);
246 }
247
248 private:
249 float tolerance_;
250 };
251 REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
252
253 } // namespace
254 } // namespace tensorflow
255