1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 19 #include "tensorflow/compiler/xla/client/xla_builder.h" 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 23 namespace tensorflow { 24 namespace { 25 26 class MatrixBandPartOp : public XlaOpKernel { 27 public: MatrixBandPartOp(OpKernelConstruction * context)28 explicit MatrixBandPartOp(OpKernelConstruction* context) 29 : XlaOpKernel(context) {} 30 Compile(XlaOpKernelContext * context)31 void Compile(XlaOpKernelContext* context) override { 32 const TensorShape input_shape = context->InputShape(0); 33 // Preliminary validation of sizes. 34 OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), 35 errors::InvalidArgument( 36 "input must be at least 2-dim, received shape: ", 37 input_shape.DebugString())); 38 39 const TensorShape num_lower_in_shape = context->InputShape(1); 40 OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape), 41 errors::InvalidArgument("num_lower must be scalar, got shape ", 42 num_lower_in_shape.DebugString())); 43 44 const TensorShape num_upper_in_shape = context->InputShape(2); 45 OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape), 46 errors::InvalidArgument("num_upper must be scalar, got shape ", 47 num_upper_in_shape.DebugString())); 48 49 xla::XlaBuilder* builder = context->builder(); 50 xla::XlaOp input = context->Input(0); 51 xla::XlaOp num_lower = context->Input(1); 52 xla::XlaOp num_upper = context->Input(2); 53 DataType input_type = context->input_type(0); 54 DataType index_type = context->input_type(1); 55 xla::PrimitiveType index_xla_type = context->input_xla_type(1); 56 57 TensorShape batch_shape = input_shape; 58 batch_shape.RemoveLastDims(2); 59 const int64 m = input_shape.dim_size(input_shape.dims() - 2); 60 const int64 n = input_shape.dim_size(input_shape.dims() - 1); 61 62 // Compute 'offset', which is how many diagonals we are above/below the 63 // diagonal. 64 xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); 65 xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); 66 xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); 67 68 auto offset = xla::Sub(iota_n, iota_m); 69 70 // If num_lower or num_upper are negative, include all lower/upper 71 // diagonals. 72 auto zero_index = XlaHelpers::Zero(builder, index_type); 73 num_lower = xla::Select(xla::Lt(num_lower, zero_index), 74 XlaHelpers::IntegerLiteral(builder, index_type, m), 75 num_lower); 76 num_upper = xla::Select(xla::Lt(num_upper, zero_index), 77 XlaHelpers::IntegerLiteral(builder, index_type, n), 78 num_upper); 79 80 auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset), 81 xla::Le(offset, num_upper)); 82 indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); 83 84 auto zero_input = XlaHelpers::Zero(builder, input_type); 85 auto output = xla::Select( 86 indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes())); 87 88 context->SetOutput(0, output); 89 } 90 91 private: 92 TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp); 93 }; 94 REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp); 95 96 } // namespace 97 } // namespace tensorflow 98