/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/lib/bfloat16/bfloat16.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shape of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
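    //
    // A minimal usage sketch (the canonical pattern is
    // BinaryOp<Device, Functor>::Compute below):
    //   BinaryOpState state(ctx);
    //   if (!ctx->status().ok()) return;
    //   Tensor* out = state.out;  // non-null whenever ctx->status().ok()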
    BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
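//
// Illustrative registration (a sketch; real registrations live in the
// cwise_op_*.cc files and use the REGISTER* macros defined near the bottom
// of this header):
//   REGISTER(BinaryOp, CPU, "Add", functor::add, float);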
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    // 'state': Shared helper not dependent on T to reduce code size
    BinaryOpState state(ctx);
    if (!ctx->status().ok()) return;
    Tensor* out = state.out;
    BCast* bcast = &state.bcast;
    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }
    const int ndims = state.ndims;
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
          in0.template shaped<Tin, 2>(bcast->x_reshape()),
          BCast::ToIndexArray<2>(bcast->x_bcast()),
          in1.template shaped<Tin, 2>(bcast->y_reshape()),
          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
          in0.template shaped<Tin, 3>(bcast->x_reshape()),
          BCast::ToIndexArray<3>(bcast->x_bcast()),
          in1.template shaped<Tin, 3>(bcast->y_reshape()),
          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
          in0.template shaped<Tin, 4>(bcast->x_reshape()),
          BCast::ToIndexArray<4>(bcast->x_bcast()),
          in1.template shaped<Tin, 4>(bcast->y_reshape()),
          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
          in0.template shaped<Tin, 5>(bcast->x_reshape()),
          BCast::ToIndexArray<5>(bcast->x_bcast()),
          in1.template shaped<Tin, 5>(bcast->y_reshape()),
          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known to not require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops_gradients.h (included above). E.g.,
//   functor::tanh_grad.
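//
// Illustrative registration (a sketch, mirroring the REGISTER macro below):
//   REGISTER(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float);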
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
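//
// Illustrative registration (a sketch):
//   REGISTER(UnaryOp, CPU, "Sqrt", functor::sqrt, float);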
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = DeviceNumaNode(ctx->device());
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it's a no-op.
      //
      // Here, we need to handle 6 cases depending on how many 1s appear
      // in in0's and in1's shapes (4 numbers in total). The two shapes
      // cannot have more than two 1s between them, because such cases are
      // simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndims combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc
      // does a decent job of avoiding generating code when the conditions
      // are not met.
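      //
      // Worked example (illustrative): in0 of shape [1, 8] and in1 of shape
      // [4, 1] hit the (a == 1 && d == 1) case below. lhs becomes
      // in0.reshape(OneByM(8)).broadcast(NByOne(4)) and rhs becomes
      // in1.reshape(NByOne(4)).broadcast(OneByM(8)), so both expressions
      // evaluate with shape [4, 8] without materializing temporaries.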
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always correct, but may be slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
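// How the error flag flows (see BinaryOp::Compute above): when
// Functor::has_errors is true, Compute passes a non-null bool* into these
// methods, the functor constructed as Functor::func(error) records failures
// through it, and Compute then reports them via SetComputeError(ctx).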
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
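//
// A sketch of typical invocations (illustrative; actual registrations are in
// the sharded cwise_op_*.cc files and the type lists vary by op):
//   REGISTER3(BinaryOp, CPU, "Sub", functor::sub, float, double, int32);
//   REGISTER2(UnaryOp, GPU, "Neg", functor::neg, float, double);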

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support, at the cost of increased
// code size.
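// An illustrative invocation (the target label here is a placeholder, not a
// real build target):
//   bazel build --copt=-D__ANDROID_TYPES_FULL__ //your/android:tensorflow_target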
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_