/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/lib/bfloat16/bfloat16.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
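//
// For example, BinaryOp<CPUDevice, functor::add<float>> is the instantiation
// that the REGISTER macros at the end of this file produce for a float "Add"
// kernel on CPU (illustrative; the actual registrations live in the
// cwise_op_*.cc files).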
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    // 'state': Shared helper not dependent on T to reduce code size
    BinaryOpState state(ctx);
    if (!ctx->status().ok()) return;
    Tensor* out = state.out;
    BCast* bcast = &state.bcast;
    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }
    const int ndims = state.ndims;
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
          in0.template shaped<Tin, 2>(bcast->x_reshape()),
          BCast::ToIndexArray<2>(bcast->x_bcast()),
          in1.template shaped<Tin, 2>(bcast->y_reshape()),
          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
          in0.template shaped<Tin, 3>(bcast->x_reshape()),
          BCast::ToIndexArray<3>(bcast->x_bcast()),
          in1.template shaped<Tin, 3>(bcast->y_reshape()),
          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
          in0.template shaped<Tin, 4>(bcast->x_reshape()),
          BCast::ToIndexArray<4>(bcast->x_bcast()),
          in1.template shaped<Tin, 4>(bcast->y_reshape()),
          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
          in0.template shaped<Tin, 5>(bcast->x_reshape()),
          BCast::ToIndexArray<5>(bcast->x_bcast()),
          in1.template shaped<Tin, 5>(bcast->y_reshape()),
          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
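    // If Tin and Tout are the same type, try to reuse one of the input
    // buffers for the output (in-place computation) instead of allocating a
    // fresh output tensor; otherwise a new buffer must be allocated.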
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = DeviceNumaNode(ctx->device());
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

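// Evaluates the Eigen expression `rhs` on device `d` and assigns the result
// to `out`, e.g. Assign(d, out, in0 + in1).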
template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

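  // The BCast overload below uses NByOne(n) and OneByM(m), which describe
  // [n, 1] and [1, m] shapes respectively, to spell out the row/column
  // reshapes and broadcasts.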
  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it's a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s
      // exist in in0's and in1's shapes (4 numbers in total). It's not
      // possible for the two shapes to have more than two 1s in total,
      // because those cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc
      // does a decent job of avoiding generating code when the conditions
      // are not met.
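      //
      // For example, in0 of shape [1, 5] and in1 of shape [3, 1] takes the
      // (a == 1) && (d == 1) branch below: both operands are reshaped and
      // broadcast to [3, 5] before the element-wise functor is applied.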
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);
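//
// For example, REGISTER(BinaryOp, CPU, "Add", functor::add, float) expands to
//   REGISTER_KERNEL_BUILDER(
//       Name("Add").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       BinaryOp<CPUDevice, functor::add<float>>);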

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
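//
// A registration file might invoke them as, e.g. (illustrative type list):
//   REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double);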

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support and a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_