• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Native XLA implementations of simple binary Ops
17 
18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 // A subclass of a XlaBinaryOp must build the computation that
34 // describes the (tensor,tensor)->tensor function to apply to each element of
35 // the input.
36 #define XLA_MAKE_BINARY(NAME, HLO)                                       \
37   class NAME##Op : public XlaBinaryOp {                                  \
38    public:                                                               \
39     explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}   \
40     xla::XlaOp Computation(                                              \
41         XlaOpKernelContext* ctx, const xla::XlaOp& lhs,                  \
42         const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, \
43         const absl::Span<const int64>& rhs_shape,                        \
44         const BCast& broadcast_helper,                                   \
45         const std::vector<int64>& extend_dimensions) override {          \
46       xla::XlaBuilder* b = ctx->builder();                               \
47       (void)b;                                                           \
48       (void)lhs_shape;                                                   \
49       (void)rhs_shape;                                                   \
50       (void)extend_dimensions;                                           \
51       return HLO;                                                        \
52     }                                                                    \
53   };                                                                     \
54   REGISTER_XLA_OP(Name(#NAME), NAME##Op)
55 
56 XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
57 XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
58 XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
59 XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
60 
61 XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
62 XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
63 
64 // Implementation of DivNoNan. Pseudo-code:
65 // if (y == 0) {
66 //   return 0
67 // } else {
68 //   return x / y;
69 // }
DivNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)70 static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
71                                xla::XlaOp y, const BCast& broadcast_helper) {
72   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
73   auto zero = XlaHelpers::Zero(b, dtype);
74   auto y_equals_0 = xla::Eq(y, zero);
75   auto zeros = xla::ZerosLike(x);
76   auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
77   return result;
78 }
79 XLA_MAKE_BINARY(DivNoNan,
80                 DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
81 
82 // Implementation of MulNoNan. Pseudo-code:
83 // if (y == 0) {
84 //   return 0
85 // } else {
86 //   return x * y;
87 // }
MulNoNanImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)88 static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
89                                xla::XlaOp y, const BCast& broadcast_helper) {
90   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
91   auto zero = XlaHelpers::Zero(b, dtype);
92   auto y_equals_0 = xla::Eq(y, zero);
93   auto zeros = xla::ZerosLike(x);
94   auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y));
95   return result;
96 }
97 XLA_MAKE_BINARY(MulNoNan,
98                 MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
99 
100 // Implementation of FloorDiv.
101 //
102 // For floating-point values, simply returns floor(x / y).  For integers, does:
103 //
104 // if ((x < 0) != (y < 0)) {
105 //   T abs_x = std::abs(x);
106 //   T abs_y = std::abs(y);
107 //   return -(abs_x + abs_y - 1) / abs_y;
108 // } else {
109 //   return x / y;
110 // }
FloorDivImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)111 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
112                                xla::XlaOp y, const BCast& broadcast_helper) {
113   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
114   if (DataTypeIsFloating(dtype)) {
115     return xla::Floor(xla::Div(x, y));
116   }
117   if (DataTypeIsUnsigned(dtype)) {
118     return xla::Div(x, y);
119   }
120   auto zero = XlaHelpers::Zero(b, dtype);
121   auto one = XlaHelpers::One(b, dtype);
122   auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
123   auto abs_x = xla::Abs(x);
124   auto abs_y = xla::Abs(y);
125   auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
126   return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
127 }
128 XLA_MAKE_BINARY(FloorDiv,
129                 FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
130 
XlogyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)131 xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y,
132                      const BCast& broadcast_helper) {
133   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
134   auto zero = xla::ZerosLike(x);
135   auto is_zero = xla::Eq(x, zero);
136   return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
137 }
138 XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
139 
XdivyImpl(xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)140 xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y,
141                      const BCast& broadcast_helper) {
142   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
143   auto zero = xla::ZerosLike(x);
144   auto is_zero = xla::Eq(x, zero);
145   return xla::Select(is_zero, zero, xla::Div(x, y));
146 }
147 XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));
148 
149 // Implementation of FloorMod. Pseudo-code:
150 // T trunc_mod = std::fmod(x, y);
151 // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
FloorModImpl(xla::XlaBuilder * b,DataType dtype,xla::XlaOp x,xla::XlaOp y,const BCast & broadcast_helper)152 static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
153                                xla::XlaOp y, const BCast& broadcast_helper) {
154   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
155   auto zero = XlaHelpers::Zero(b, dtype);
156   auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
157   auto trunc_mod = xla::Rem(x, y);
158   return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
159 }
160 XLA_MAKE_BINARY(FloorMod,
161                 FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
162 
163 XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
164 XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
165 XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));
166 
167 XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
168 XLA_MAKE_BINARY(RightShift,
169                 (DataTypeIsUnsigned(ctx->input_type(0))
170                      ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
171                      : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));
172 
173 XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
174 XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
175 XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
176 XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
177 XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
178 XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
179 XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
180 XLA_MAKE_BINARY(
181     RsqrtGrad,
182     xla::Mul((lhs * lhs) * lhs,
183              xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
184              extend_dimensions));
185 XLA_MAKE_BINARY(
186     SqrtGrad,
187     xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
188              lhs, extend_dimensions));
189 
190 XLA_MAKE_BINARY(SquaredDifference,
191                 xla::Square(xla::Sub(lhs, rhs, extend_dimensions)));
192 
193 XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
194 XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));
195 
196 // Comparison ops
197 XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
198 XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
199 XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
200 XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
201 XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
202 XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));
203 
204 // Non-linear ops
205 XLA_MAKE_BINARY(SigmoidGrad,
206                 xla::Mul(xla::Mul(rhs, lhs),
207                          xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));
208 
209 XLA_MAKE_BINARY(SoftplusGrad,
210                 xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)),
211                                        XlaHelpers::One(b, input_type(1)))));
212 
213 // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
214 XLA_MAKE_BINARY(SoftsignGrad,
215                 xla::Div(lhs,
216                          xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)),
217                                               xla::Abs(rhs)))));
218 
219 XLA_MAKE_BINARY(TanhGrad,
220                 xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
221                                        xla::Mul(lhs, lhs))));
222 
223 XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));
224 
225 XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs));
226 
227 #undef XLA_MAKE_BINARY
228 
229 class ApproximateEqualOp : public XlaOpKernel {
230  public:
ApproximateEqualOp(OpKernelConstruction * ctx)231   explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
232     OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_));
233   }
234 
235   // Computes the max of the scalar input x and 0.
Compile(XlaOpKernelContext * ctx)236   void Compile(XlaOpKernelContext* ctx) override {
237     xla::XlaBuilder* b = ctx->builder();
238     auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
239     auto abs_shape = b->GetShape(abs);
240     OP_REQUIRES_OK(ctx, abs_shape.status());
241     auto abs_type = abs_shape.ValueOrDie().element_type();
242     auto result =
243         xla::Lt(abs, xla::ConvertElementType(
244                          xla::ConstantR0<float>(b, tolerance_), abs_type));
245     ctx->SetOutput(0, result);
246   }
247 
248  private:
249   float tolerance_;
250 };
251 REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp);
252 
253 }  // namespace
254 }  // namespace tensorflow
255