/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/lib/bfloat16/bfloat16.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and, if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
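    //
    // Typical usage (a minimal sketch, mirroring BinaryOp::Compute below):
    //   BinaryOpState state(ctx);
    //   if (!ctx->status().ok()) return;
    //   // state.out is allocated and state.bcast describes the broadcast.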
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
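//
// For illustration (a sketch, not an exhaustive list; the actual
// registrations live in the cwise_op_*.cc files): a float "Add" kernel on
// the CPU corresponds to an instantiation along the lines of
//   BinaryOp<CPUDevice, functor::add<float>>
// registered via the REGISTER* macros defined at the bottom of this file.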
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    const Tensor& input_1 = ctx->input(1);
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': Shared helper not dependent on T to reduce code size
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known to not require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
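//
// For illustration (a sketch; the actual registrations live in the
// cwise_op_*.cc files): a float "TanhGrad" kernel on the CPU corresponds to
// an instantiation along the lines of
//   SimpleBinaryOp<CPUDevice, functor::tanh_grad<float>>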
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
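//
// For illustration (a sketch; the actual registrations live in the
// cwise_op_*.cc files): a float "Sqrt" kernel on the CPU corresponds to an
// instantiation along the lines of
//   UnaryOp<CPUDevice, functor::sqrt<float>>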
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s
      // exist in in0's and in1's shapes (4 numbers in total). It's not
      // possible for the two shapes to have more than two 1s in total
      // because those cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization and
      // use_bcast_optimization<T> are compile-time constants, gcc
      // does a decent job of avoiding generating code when the conditions
      // are not met.
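      //
      // Illustrative example: with in0 of shape [1, 5] and in1 of shape
      // [3, 1], a == 1 and d == 1, so the first branch below broadcasts
      // in0 by NByOne(c) and in1 by OneByM(b), yielding two [3, 5]
      // expressions whose unit broadcast factors are compile-time
      // constants (Eigen::type2index<1>).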
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
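//
// For example (an illustrative sketch of how these macros are typically used
// from a cwise_op_*.cc file; the exact type lists there may differ):
//
//   REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
//
// expands to three REGISTER_KERNEL_BUILDER calls, one per type, each
// registering UnaryOp<CPUDevice, functor::sqrt<T>> for the "Sqrt" op with the
// corresponding TypeConstraint<T>("T").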

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_