• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// XLA-specific base class for coefficient-wise binary Ops.
17 
18 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
19 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
20 
21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22 #include "tensorflow/compiler/xla/client/client_library.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/util/bcast.h"
26 
27 namespace tensorflow {
28 
29 // Coefficient-wise binary operations. Each binary Op expects two
30 // inputs that can be broadcast to the same shape. The base class
31 // contains pure virtual methods to override: description is a textual
32 // description of the operation; and Computation adds the
33 // implementation of the operation to a xla::XlaBuilder. For most
34 // arithmetic Ops XLA handles the broadcasting automatically given the input
35 // tensors.
36 class XlaBinaryOp : public XlaOpKernel {
37  public:
XlaBinaryOp(OpKernelConstruction * ctx)38   explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
39     const DataType lhs = BaseType(input_type(0));
40     const DataType rhs = BaseType(input_type(1));
41     OP_REQUIRES(ctx, lhs == rhs,
42                 errors::InvalidArgument("Input types of binary op must match"));
43   }
~XlaBinaryOp()44   ~XlaBinaryOp() override {}
45 
46   // Implement the (tensor,tensor)->tensor lambda that should be
47   // applied to the inputs. The desired computation should be added to
48   // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and
49   // (lhs_shape,rhs_shape) are their respective
50   // shapes. 'broadcast_helper' contains metadata about the shapes of
51   // the inputs and the dimensions that need to be broadcast, which
52   // may be useful for Ops that can't use standard XLA automatic
53   // broadcasting. 'extend_dimension' is non-empty if lhs and rhs have
54   // different ranks, and indicates which dimensions of the
55   // higher-rank input should be matched when broadcasting the
56   // lower-rank input. See comment below and the documentation on broadcasting
57   // in the XLA documentation.
58   virtual xla::XlaOp Computation(
59       XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
60       const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
61       const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
62       const std::vector<int64>& extend_dimensions) = 0;
63 
64   void Compile(XlaOpKernelContext* ctx) override;
65 
66   // Helper function that performs the broadcasting described by
67   // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
68   // shape.
69   static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
70       xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
71 };
72 
73 }  // namespace tensorflow
74 
75 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
76