/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/strided_slice_op.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {
namespace {

class StridedSliceOp : public XlaOpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    TensorShape final_shape;
    absl::InlinedVector<int64, 4> begin;
    absl::InlinedVector<int64, 4> end;
    absl::InlinedVector<int64, 4> strides;

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    TensorShape dummy_processing_shape;
    bool dummy = false;
    OP_REQUIRES_OK(ctx,
                   ValidateStridedSliceOp(
                       &begin_tensor, &end_tensor, strides_tensor, input_shape,
                       begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                       shrink_axis_mask_, &dummy_processing_shape, &final_shape,
                       &dummy, &dummy, &dummy, &begin, &end, &strides));

    absl::InlinedVector<int64, 4> dimensions_to_reverse;
    absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;

    for (int i = 0; i < begin.size(); ++i) {
      if (strides[i] > 0) {
        slice_begin.push_back(begin[i]);
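        // Clamp end so that end >= begin; when end < begin this yields an
        // empty slice.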
        slice_end.push_back(std::max(end[i], begin[i]));
        slice_strides.push_back(strides[i]);
      } else {
        // Negative stride: swap begin and end, add 1 because the interval
        // is semi-open, and mark the dimension to be reversed.
        slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
        slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
                                     input_shape.dim_size(i) - begin[i] - 1));
        slice_strides.push_back(-strides[i]);
        dimensions_to_reverse.push_back(i);
      }
    }

    xla::XlaOp slice = ctx->Input(0);
    if (!dimensions_to_reverse.empty()) {
      slice = xla::Rev(slice, dimensions_to_reverse);
    }

    slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);

    slice = xla::Reshape(slice, final_shape.dim_sizes());
    ctx->SetOutput(0, slice);
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("StridedSlice")
                    .CompileTimeConstantInput("begin")
                    .CompileTimeConstantInput("end")
                    .CompileTimeConstantInput("strides"),
                StridedSliceOp);

class StridedSliceGradOp : public XlaOpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape processing_shape, final_shape;
    absl::InlinedVector<int64, 4> begin;
    absl::InlinedVector<int64, 4> end;
    absl::InlinedVector<int64, 4> strides;

    TensorShape input_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    bool dummy = false;
    OP_REQUIRES_OK(
        ctx, ValidateStridedSliceOp(
                 &begin_tensor, &end_tensor, strides_tensor, input_shape,
                 begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                 shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
                 &dummy, &dummy, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice.
    const TensorShape dy_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    OP_REQUIRES(
        ctx, input_shape.dims() == processing_shape.dims(),
        errors::Internal(
            "input shape and processing shape must have same number of dims"));

    auto zero =
        XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
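    // Scatter the gradient back to the original input shape by padding dy
    // with zeros: interior padding re-inserts the elements skipped by the
    // stride, and edge padding fills the region outside the sliced range.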
    xla::XlaOp grad = ctx->Input(4);

    // Undo any new/shrink axes.
    grad = xla::Reshape(grad, processing_shape.dim_sizes());

    // Pad the input gradients.
    absl::InlinedVector<int64, 4> dimensions_to_reverse;
    xla::PaddingConfig padding_config;

    for (int i = 0; i < processing_shape.dims(); ++i) {
      auto* dims = padding_config.add_dimensions();
      if (strides[i] > 0) {
        dims->set_edge_padding_low(begin[i]);
        dims->set_interior_padding(strides[i] - 1);

        // Pad the upper dimension up to the expected input shape. (It's not
        // sufficient simply to use "end[i]" to compute the padding in cases
        // where the stride does not divide evenly into the interval between
        // begin[i] and end[i].)
        int64 size =
            dims->edge_padding_low() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_high(input_shape.dim_size(i) - size);
      } else {
        dimensions_to_reverse.push_back(i);
        dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
        dims->set_interior_padding(-strides[i] - 1);

        // Pad the lower dimension up to the expected input shape.
        int64 size =
            dims->edge_padding_high() + processing_shape.dim_size(i) +
            (processing_shape.dim_size(i) - 1) * dims->interior_padding();
        dims->set_edge_padding_low(input_shape.dim_size(i) - size);
      }
    }
    if (!dimensions_to_reverse.empty()) {
      grad = xla::Rev(grad, dimensions_to_reverse);
    }
    grad = xla::Pad(grad, zero, padding_config);
    ctx->SetOutput(0, grad);
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
};

REGISTER_XLA_OP(Name("StridedSliceGrad")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("begin")
                    .CompileTimeConstantInput("end")
                    .CompileTimeConstantInput("strides"),
                StridedSliceGradOp);

class StridedSliceAssignOp : public XlaOpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape final_shape;
    absl::InlinedVector<int64, 4> begin;
    absl::InlinedVector<int64, 4> end;
    absl::InlinedVector<int64, 4> strides;

    xla::Literal begin_literal, end_literal, strides_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));

    Tensor begin_tensor, end_tensor, strides_tensor;
    OP_REQUIRES_OK(
        ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    OP_REQUIRES_OK(ctx,
                   LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));
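    // Read the current value of the resource variable; the assignment is
    // performed as a DynamicUpdateSlice on that value and written back below.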
    TensorShape lhs_shape;
    xla::XlaOp lhs;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));

    const TensorShape rhs_shape = ctx->InputShape(4);

    TensorShape dummy_processing_shape;
    bool dummy = false;
    OP_REQUIRES_OK(ctx,
                   ValidateStridedSliceOp(
                       &begin_tensor, &end_tensor, strides_tensor, lhs_shape,
                       begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
                       shrink_axis_mask_, &dummy_processing_shape, &final_shape,
                       &dummy, &dummy, &dummy, &begin, &end, &strides));

    if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
      // DynamicUpdateSlice does not allow 0-element updates. We should
      // probably check that rhs_shape can be broadcast to final_shape, but
      // that is better handled when implementing broadcasting more generally.
      return;
    }

    // TODO(aselle): This check is too strong; we should only need input_shape
    // to be broadcastable to final_shape.
    OP_REQUIRES(ctx, final_shape == rhs_shape,
                errors::Unimplemented(
                    "sliced l-value shape ", final_shape.DebugString(),
                    " does not match r-value shape ", rhs_shape.DebugString(),
                    ". Automatic broadcasting not yet implemented."));

    xla::XlaOp rhs = ctx->Input(4);

    absl::InlinedVector<int64, 4> dimensions_to_reverse;
    absl::InlinedVector<xla::XlaOp, 4> slice_begin;
    absl::InlinedVector<int64, 4> slice_dims;
    for (int i = 0; i < begin.size(); ++i) {
      // TODO(b/121179231): implement strides != 1
      OP_REQUIRES(
          ctx, strides[i] == 1 || strides[i] == -1,
          errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
      if (strides[i] > 0) {
        slice_begin.push_back(xla::ConstantR0<int64>(ctx->builder(), begin[i]));
        slice_dims.push_back(end[i] - begin[i]);
      } else {
        // Negative stride: swap begin and end, add 1 because the interval
        // is semi-open, and mark the dimension to be reversed.
        slice_begin.push_back(
            xla::ConstantR0<int64>(ctx->builder(), end[i] + 1));
        slice_dims.push_back(begin[i] - end[i]);
        dimensions_to_reverse.push_back(i);
      }
    }

    if (!dimensions_to_reverse.empty()) {
      rhs = xla::Rev(rhs, dimensions_to_reverse);
    }
    rhs = xla::Reshape(rhs, slice_dims);

    lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
  DataType dtype_;
};

REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
                    .CompileTimeConstantInput("begin")
                    .CompileTimeConstantInput("end")
                    .CompileTimeConstantInput("strides"),
                StridedSliceAssignOp);

}  // namespace
}  // namespace tensorflow