/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.
#define _USE_MATH_DEFINES
#include <cmath>

#define EIGEN_USE_THREADS

#include "tensorflow/core/lib/bfloat16/bfloat16.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    explicit BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
    bool result;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_0 = ctx->input(0);
    const Tensor& input_1 = ctx->input(1);
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;

    // NOTE: Handle three simple cases before building the BinaryOpState, which
    // is relatively expensive for small operations.
    if (input_0.shape() == input_1.shape()) {
      // tensor op tensor with no broadcasting.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>()(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_0.shape().dims() == 0) {
      // scalar op tensor.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {1}, 0, input_1.shape(), &out));

      functor::BinaryFunctor<Device, Functor, 1>().Left(
          eigen_device, out->template flat<Tout>(),
          input_0.template scalar<Tin>(), input_1.template flat<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    } else if (input_1.shape().dims() == 0) {
      // tensor op scalar.
      Tensor* out;
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, input_0.shape(), &out));
      functor::BinaryFunctor<Device, Functor, 1>().Right(
          eigen_device, out->template flat<Tout>(),
          input_0.template flat<Tin>(), input_1.template scalar<Tin>(),
          error_ptr);
      if (Functor::has_errors && error) {
        SetComputeError(ctx);
      }
      return;
    }

    // 'state': Shared helper not dependent on T to reduce code size
    BinaryOpState state(ctx);
    if (ctx->status().code() == error::RESOURCE_EXHAUSTED) {
      // Stop when BinaryOpState's constructor failed due to OOM.
      return;
    }
    auto& bcast = state.bcast;
    Tensor* out = state.out;
    if (!bcast.IsValid()) {
      if (ctx->status().ok()) {
        if (state.result) {
          functor::SetOneFunctor<Device, bool>()(eigen_device,
                                                 out->flat<bool>());
        } else {
          functor::SetZeroFunctor<Device, bool>()(eigen_device,
                                                  out->flat<bool>());
        }
      }
      return;
    }

    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }

    const int ndims = state.ndims;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
          in0.template shaped<Tin, 2>(bcast.x_reshape()),
          BCast::ToIndexArray<2>(bcast.x_bcast()),
          in1.template shaped<Tin, 2>(bcast.y_reshape()),
          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
          in0.template shaped<Tin, 3>(bcast.x_reshape()),
          BCast::ToIndexArray<3>(bcast.x_bcast()),
          in1.template shaped<Tin, 3>(bcast.y_reshape()),
          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
          in0.template shaped<Tin, 4>(bcast.x_reshape()),
          BCast::ToIndexArray<4>(bcast.x_bcast()),
          in1.template shaped<Tin, 4>(bcast.y_reshape()),
          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
          in0.template shaped<Tin, 5>(bcast.x_reshape()),
          BCast::ToIndexArray<5>(bcast.x_bcast()),
          in1.template shaped<Tin, 5>(bcast.y_reshape()),
          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};
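
// Example (illustrative only; actual registrations live in the cwise_op_*.cc
// files, whose op names and type lists may differ): a BinaryOp kernel is
// typically instantiated through the REGISTER* macros defined at the bottom
// of this header, e.g.
//
//   REGISTER(BinaryOp, CPU, "Add", functor::add, float);
//
// which expands to a REGISTER_KERNEL_BUILDER call binding the "Add" op with
// T=float on DEVICE_CPU to BinaryOp<CPUDevice, functor::add<float>>.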

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};
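
// Note: ApproximateEqual compares element-wise and writes a boolean output;
// with the CPU functor defined further below this amounts to
// |x - y| <= tolerance. For example (illustrative values), with
// tolerance = 1e-5, x = {1.0f, 2.0f} and y = {1.000001f, 2.1f}, the output
// is {true, false}.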

// Basic coefficient-wise binary operations that are known to not require
// any broadcasting. This is the case for example of the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};
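
// For instance (illustrative; the gradient functors themselves are defined in
// cwise_ops_gradients.h), a SimpleBinaryOp instantiated with
// functor::tanh_grad receives y = tanh(x) and the incoming gradient dy as its
// two same-shaped inputs and writes dy * (1 - y^2) to the output, with no
// broadcasting involved.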

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};
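
// Note: as with SimpleBinaryOp above, when Tin and Tout are the same type the
// kernel tries to forward an input buffer and compute in place
// (forward_input_or_allocate_output); otherwise a fresh output buffer is
// allocated.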

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    int numa_node = ctx->device()->NumaNode();
    Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it's a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s
      // exist in in0's and in1's shapes (4 numbers in total). Taken
      // together, the two shapes cannot have more than two 1s, because
      // such cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization and
      // use_bcast_optimization<T> are compile-time constants, gcc
      // does a decent job avoiding generating code when conditions
      // are not met.
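      //
      // Illustrative example (shapes assumed for exposition): for in0 of
      // shape [1, 8] and in1 of shape [4, 1], the (a == 1 && d == 1) branch
      // below broadcasts in0 across 4 rows and in1 across 8 columns,
      // producing a [4, 8] output without materializing either broadcast.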
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef
        typename Eigen::internal::scalar_left<Tout, Tin, Binary,
                                              /*is_scalar_in_host_memory=*/true>
            Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<
        Tout, Tin, Binary, /*is_scalar_in_host_memory=*/true>
        Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
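//
// For example (illustrative), REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt,
// float, double, bfloat16) expands step by step into one
// REGISTER_KERNEL_BUILDER call per listed type; under __ANDROID_TYPES_SLIM__
// (see below) only the first type is registered.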

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
689