/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/platform/bfloat16.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
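//
// Illustrative usage (actual registrations live in the cwise_op_*.cc files,
// via the REGISTER* macros defined at the bottom of this header), e.g.:
//   REGISTER2(BinaryOp, CPU, "Add", functor::add, float, double)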
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    const Tensor& input_1 = ctx->input(1);
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': Shared helper not dependent on T to reduce code size.
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

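// Computes, element-wise, whether |x - y| <= tolerance. x and y must have the
// same shape; see the ApproximateEqual functor below for the CPU
// implementation.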
template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
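//
// Illustrative usage (actual registrations live in the cwise_op_*.cc files):
//   REGISTER(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float)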
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(
        ctx, in0.NumElements() == in1.NumElements(),
        errors::InvalidArgument("The two arguments to a cwise op must have "
                                "same number of elements, got ",
                                in0.NumElements(), " and ", in1.NumElements()));
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
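//
// Illustrative instantiation: UnaryOp<CPUDevice, functor::sqrt<float>>
// computes element-wise sqrt; concrete kernels are registered with the
// REGISTER* macros at the bottom of this header.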
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

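// Coefficient-wise unary operation on a scalar DT_VARIANT input: unwraps the
// Variant, applies the variant unary op registered for OpEnum via
// UnaryOpVariant<Device>, and wraps the result in a new scalar Variant output.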
template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

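// Evaluates the Eigen expression 'rhs' into 'out' on device 'd'.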
template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s appear
      // in in0's and in1's shapes (4 numbers in total). It is not possible
      // for the two shapes to have more than two 1s, because those cases
      // are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc does a
      // decent job of avoiding generating code when the conditions are not
      // met.
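      //
      // For example (illustrative): in0 of shape [1, 8] combined with in1 of
      // shape [8, 1] takes the (a == 1) && (d == 1) branch below, so both
      // operands are reshaped through NByOne/OneByM (with compile-time "1"
      // dimensions) before being broadcast against each other.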
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
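//
// For example (illustrative), a cwise_op_*.cc file might contain
//   REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double, Eigen::half)
// which, outside the __ANDROID_TYPES_SLIM__ build, expands to one REGISTER()
// invocation per listed type.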

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_