1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/kernels/while_op.h" 17 18 #include "tensorflow/compiler/tf2xla/shape_util.h" 19 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/client/xla_computation.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 27 namespace tensorflow { 28 namespace { 29 30 class XlaSelectAndScatterOp : public XlaOpKernel { 31 public: XlaSelectAndScatterOp(OpKernelConstruction * context)32 explicit XlaSelectAndScatterOp(OpKernelConstruction* context) 33 : XlaOpKernel(context) { 34 OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_)); 35 OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_)); 36 } 37 Compile(XlaOpKernelContext * context)38 void Compile(XlaOpKernelContext* context) override { 39 const TensorShape input_shape = context->InputShape(0); 40 const DataType dtype = context->input_type(0); 41 42 std::vector<int64> window_dimensions; 43 std::vector<int64> window_strides; 44 OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( 45 "window_dimensions", &window_dimensions)); 46 OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", 47 &window_strides)); 48 49 const int rank = input_shape.dims(); 50 OP_REQUIRES(context, rank == window_dimensions.size(), 51 errors::InvalidArgument( 52 "The size of window_dimensions must be equal to the input " 53 "rank (", 54 window_dimensions.size(), " vs. ", rank, ")")); 55 OP_REQUIRES(context, rank == window_strides.size(), 56 errors::InvalidArgument( 57 "The size of window_strides must be equal to the input " 58 "rank (", 59 window_strides.size(), " vs. ", rank, ")")); 60 61 XlaCompiler::CompileOptions compile_options; 62 compile_options.use_tuple_arg = false; 63 compile_options.is_entry_computation = false; 64 compile_options.always_return_tuple = false; 65 66 // Build the select function. 67 XlaCompiler::Argument select_arg; 68 select_arg.kind = XlaCompiler::Argument::kParameter; 69 select_arg.type = dtype; 70 select_arg.shape = TensorShape(); 71 72 XlaCompiler::CompilationResult select; 73 OP_REQUIRES_OK(context, context->compiler()->CompileFunction( 74 compile_options, *select_computation_, 75 {select_arg, select_arg}, &select)); 76 77 xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {}); 78 OP_REQUIRES( 79 context, 80 xla::ShapeUtil::Compatible(select.xla_output_shape, 81 select_output_shape), 82 errors::InvalidArgument( 83 "Invalid output shape of XlaSelectAndScatter select. Expected ", 84 xla::ShapeUtil::HumanString(select_output_shape), " got ", 85 xla::ShapeUtil::HumanString(select.xla_output_shape))); 86 87 // Build the scatter function. 88 XlaCompiler::Argument scatter_arg; 89 scatter_arg.kind = XlaCompiler::Argument::kParameter; 90 scatter_arg.type = dtype; 91 scatter_arg.shape = TensorShape(); 92 93 XlaCompiler::CompilationResult scatter; 94 OP_REQUIRES_OK(context, context->compiler()->CompileFunction( 95 compile_options, *scatter_computation_, 96 {scatter_arg, scatter_arg}, &scatter)); 97 98 xla::Shape scalar_shape; 99 OP_REQUIRES_OK(context, 100 TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); 101 OP_REQUIRES( 102 context, 103 xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape), 104 errors::InvalidArgument( 105 "Invalid output shape of scatter. Expected ", 106 xla::ShapeUtil::HumanString(scalar_shape), " got ", 107 xla::ShapeUtil::HumanString(scatter.xla_output_shape))); 108 109 const TensorShape padding_shape = context->InputShape("padding"); 110 OP_REQUIRES(context, 111 TensorShapeUtils::IsMatrix(padding_shape) && 112 padding_shape.dim_size(1) == 2, 113 errors::InvalidArgument( 114 "padding must be a matrix with minor dimension 2, got ", 115 padding_shape.DebugString())); 116 xla::Literal padding_literal; 117 OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( 118 "padding", &padding_literal)); 119 std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0)); 120 for (int i = 0; i < padding.size(); ++i) { 121 padding[i] = {padding_literal.Get<int64>({i, 0}), 122 padding_literal.Get<int64>({i, 1})}; 123 } 124 125 xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding( 126 context->Input("operand"), *select.computation, window_dimensions, 127 window_strides, padding, context->Input("source"), 128 context->Input("init_value"), *scatter.computation); 129 context->SetOutput(0, output); 130 } 131 132 private: 133 const NameAttrList* select_computation_; 134 const NameAttrList* scatter_computation_; 135 136 TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp); 137 }; 138 139 REGISTER_XLA_OP(Name("XlaSelectAndScatter") 140 .CompileTimeConstantInput("window_dimensions") 141 .CompileTimeConstantInput("window_strides") 142 .CompileTimeConstantInput("padding"), 143 XlaSelectAndScatterOp); 144 145 } // namespace 146 } // namespace tensorflow 147