1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific base classes for Unary and Binary Ops. 17 18 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 19 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 20 21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 22 #include "tensorflow/compiler/xla/client/client_library.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/util/bcast.h" 26 27 namespace tensorflow { 28 29 // Coefficient-wise binary operations. Each binary Op expects two 30 // inputs that can be broadcast to the same shape. The base class 31 // contains pure virtual methods to override: description is a textual 32 // description of the operation; and Computation adds the 33 // implementation of the operation to a xla::XlaBuilder. For most 34 // arithmetic Ops XLA handles the broadcasting automatically given the input 35 // tensors. 36 class XlaBinaryOp : public XlaOpKernel { 37 public: XlaBinaryOp(OpKernelConstruction * ctx)38 explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 39 const DataType lhs = BaseType(input_type(0)); 40 const DataType rhs = BaseType(input_type(1)); 41 OP_REQUIRES(ctx, lhs == rhs, 42 errors::InvalidArgument("Input types of binary op must match")); 43 } ~XlaBinaryOp()44 ~XlaBinaryOp() override {} 45 46 // Implement the (tensor,tensor)->tensor lambda that should be 47 // applied to the inputs. The desired computation should be added to 48 // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and 49 // (lhs_shape,rhs_shape) are their respective 50 // shapes. 'broadcast_helper' contains metadata about the shapes of 51 // the inputs and the dimensions that need to be broadcast, which 52 // may be useful for Ops that can't use standard XLA automatic 53 // broadcasting. 'extend_dimension' is non-empty if lhs and rhs have 54 // different ranks, and indicates which dimensions of the 55 // higher-rank input should be matched when broadcasting the 56 // lower-rank input. See comment below and the documentation on broadcasting 57 // in the XLA documentation. 58 virtual xla::XlaOp Computation( 59 XlaOpKernelContext* ctx, const xla::XlaOp& lhs, 60 const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, 61 const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper, 62 const std::vector<int64>& extend_dimensions) = 0; 63 64 void Compile(XlaOpKernelContext* ctx) override; 65 66 // Helper function that performs the broadcasting described by 67 // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same 68 // shape. 69 static std::pair<xla::XlaOp, xla::XlaOp> Broadcast( 70 xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); 71 }; 72 73 } // namespace tensorflow 74 75 #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 76