• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/strided_slice_op.h"
17 #include "absl/types/span.h"
18 #include "tensorflow/compiler/tf2xla/literal_util.h"
19 #include "tensorflow/compiler/tf2xla/type_util.h"
20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/ops_util.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/mem.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 class StridedSliceOp : public XlaOpKernel {
35  public:
StridedSliceOp(OpKernelConstruction * ctx)36   explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
37     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
38     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
39     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
40     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
41     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
42     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
43   }
44 
Compile(XlaOpKernelContext * ctx)45   void Compile(XlaOpKernelContext* ctx) override {
46     const TensorShape input_shape = ctx->InputShape(0);
47 
48     TensorShape final_shape;
49     absl::InlinedVector<int64, 4> begin;
50     absl::InlinedVector<int64, 4> end;
51     absl::InlinedVector<int64, 4> strides;
52 
53     xla::Literal begin_literal, end_literal, strides_literal;
54     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
55     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
56     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
57 
58     Tensor begin_tensor, end_tensor, strides_tensor;
59     OP_REQUIRES_OK(
60         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
61     OP_REQUIRES_OK(ctx,
62                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
63     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
64                                             &strides_tensor));
65 
66     TensorShape dummy_processing_shape;
67     bool dummy = false;
68     OP_REQUIRES_OK(ctx,
69                    ValidateStridedSliceOp(
70                        &begin_tensor, &end_tensor, strides_tensor, input_shape,
71                        begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
72                        shrink_axis_mask_, &dummy_processing_shape, &final_shape,
73                        &dummy, &dummy, &dummy, &begin, &end, &strides));
74 
75     absl::InlinedVector<int64, 4> dimensions_to_reverse;
76     absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
77 
78     for (int i = 0; i < begin.size(); ++i) {
79       if (strides[i] > 0) {
80         slice_begin.push_back(begin[i]);
81         slice_end.push_back(std::max(end[i], begin[i]));
82         slice_strides.push_back(strides[i]);
83       } else {
84         // Negative stride: swap begin and end, add 1 because the interval
85         // is semi-open, and mark the dimension to be reversed.
86         slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
87         slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1,
88                                      input_shape.dim_size(i) - begin[i] - 1));
89         slice_strides.push_back(-strides[i]);
90         dimensions_to_reverse.push_back(i);
91       }
92     }
93 
94     xla::XlaOp slice = ctx->Input(0);
95     if (!dimensions_to_reverse.empty()) {
96       slice = xla::Rev(slice, dimensions_to_reverse);
97     }
98 
99     slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
100 
101     slice = xla::Reshape(slice, final_shape.dim_sizes());
102     ctx->SetOutput(0, slice);
103   }
104 
105  private:
106   int32 begin_mask_, end_mask_;
107   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
108   DataType index_type_;
109 };
110 
111 REGISTER_XLA_OP(Name("StridedSlice")
112                     .CompileTimeConstantInput("begin")
113                     .CompileTimeConstantInput("end")
114                     .CompileTimeConstantInput("strides"),
115                 StridedSliceOp);
116 
117 class StridedSliceGradOp : public XlaOpKernel {
118  public:
StridedSliceGradOp(OpKernelConstruction * ctx)119   explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
120     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
121     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
122     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
123     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
124     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
125     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
126   }
127 
Compile(XlaOpKernelContext * ctx)128   void Compile(XlaOpKernelContext* ctx) override {
129     TensorShape processing_shape, final_shape;
130     absl::InlinedVector<int64, 4> begin;
131     absl::InlinedVector<int64, 4> end;
132     absl::InlinedVector<int64, 4> strides;
133 
134     TensorShape input_shape;
135     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
136 
137     xla::Literal begin_literal, end_literal, strides_literal;
138     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
139     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
140     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
141 
142     Tensor begin_tensor, end_tensor, strides_tensor;
143     OP_REQUIRES_OK(
144         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
145     OP_REQUIRES_OK(ctx,
146                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
147     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
148                                             &strides_tensor));
149 
150     bool dummy = false;
151     OP_REQUIRES_OK(
152         ctx, ValidateStridedSliceOp(
153                  &begin_tensor, &end_tensor, strides_tensor, input_shape,
154                  begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
155                  shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
156                  &dummy, &dummy, &begin, &end, &strides));
157 
158     // Check to make sure dy is consistent with the original slice
159     const TensorShape dy_shape = ctx->InputShape(4);
160     OP_REQUIRES(
161         ctx, final_shape == dy_shape,
162         errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
163                                 " instead of ", final_shape.DebugString()));
164 
165     OP_REQUIRES(
166         ctx, input_shape.dims() == processing_shape.dims(),
167         errors::Internal(
168             "input shape and processing shape must have same number of dims"));
169 
170     auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
171 
172     xla::XlaOp grad = ctx->Input(4);
173 
174     // Undo any new/shrink axes.
175     grad = xla::Reshape(grad, processing_shape.dim_sizes());
176 
177     // Pad the input gradients.
178     absl::InlinedVector<int64, 4> dimensions_to_reverse;
179     xla::PaddingConfig padding_config;
180 
181     for (int i = 0; i < processing_shape.dims(); ++i) {
182       auto* dims = padding_config.add_dimensions();
183       if (strides[i] > 0) {
184         dims->set_edge_padding_low(begin[i]);
185         dims->set_interior_padding(strides[i] - 1);
186 
187         // Pad the upper dimension up to the expected input shape. (It's
188         // not sufficient simply to use "end[i]" to compute the padding in
189         // cases where the stride does not divide evenly into the interval
190         // between begin[i] and end[i].)
191         int64 size =
192             dims->edge_padding_low() + processing_shape.dim_size(i) +
193             (processing_shape.dim_size(i) - 1) * dims->interior_padding();
194         dims->set_edge_padding_high(input_shape.dim_size(i) - size);
195       } else {
196         dimensions_to_reverse.push_back(i);
197         dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
198         dims->set_interior_padding(-strides[i] - 1);
199 
200         // Pad the lower dimension up to the expected input shape.
201         int64 size =
202             dims->edge_padding_high() + processing_shape.dim_size(i) +
203             (processing_shape.dim_size(i) - 1) * dims->interior_padding();
204         dims->set_edge_padding_low(input_shape.dim_size(i) - size);
205       }
206     }
207     if (!dimensions_to_reverse.empty()) {
208       grad = xla::Rev(grad, dimensions_to_reverse);
209     }
210     grad = xla::Pad(grad, zero, padding_config);
211     ctx->SetOutput(0, grad);
212   }
213 
214  private:
215   int32 begin_mask_, end_mask_;
216   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
217   DataType index_type_;
218 };
219 
220 REGISTER_XLA_OP(Name("StridedSliceGrad")
221                     .CompileTimeConstantInput("shape")
222                     .CompileTimeConstantInput("begin")
223                     .CompileTimeConstantInput("end")
224                     .CompileTimeConstantInput("strides"),
225                 StridedSliceGradOp);
226 
227 class StridedSliceAssignOp : public XlaOpKernel {
228  public:
StridedSliceAssignOp(OpKernelConstruction * ctx)229   explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
230     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
231     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
232     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
233     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
234     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
235     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
236     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
237   }
238 
Compile(XlaOpKernelContext * ctx)239   void Compile(XlaOpKernelContext* ctx) override {
240     TensorShape final_shape;
241     absl::InlinedVector<int64, 4> begin;
242     absl::InlinedVector<int64, 4> end;
243     absl::InlinedVector<int64, 4> strides;
244 
245     xla::Literal begin_literal, end_literal, strides_literal;
246     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
247     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
248     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
249 
250     Tensor begin_tensor, end_tensor, strides_tensor;
251     OP_REQUIRES_OK(
252         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
253     OP_REQUIRES_OK(ctx,
254                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
255     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
256                                             &strides_tensor));
257 
258     TensorShape lhs_shape;
259     xla::XlaOp lhs;
260     OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
261 
262     const TensorShape rhs_shape = ctx->InputShape(4);
263 
264     TensorShape dummy_processing_shape;
265     bool dummy = false;
266     OP_REQUIRES_OK(ctx,
267                    ValidateStridedSliceOp(
268                        &begin_tensor, &end_tensor, strides_tensor, lhs_shape,
269                        begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
270                        shrink_axis_mask_, &dummy_processing_shape, &final_shape,
271                        &dummy, &dummy, &dummy, &begin, &end, &strides));
272 
273     if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
274       // DynamicUpdateSlice does not allow 0-element updates. We should probably
275       // check that rhs_shape can be broadcast to final_shape, but that is
276       // probably better handled when implementing broadcasting more generally.
277       return;
278     }
279 
280     // TODO(aselle): This check is too strong, we only should need
281     // input_shape to be broadcastable to final_shape
282     OP_REQUIRES(ctx, final_shape == rhs_shape,
283                 errors::Unimplemented(
284                     "sliced l-value shape ", final_shape.DebugString(),
285                     " does not match r-value shape ", rhs_shape.DebugString(),
286                     ". Automatic broadcasting not yet implemented."));
287 
288     xla::XlaOp rhs = ctx->Input(4);
289 
290     absl::InlinedVector<int64, 4> dimensions_to_reverse;
291     absl::InlinedVector<xla::XlaOp, 4> slice_begin;
292     absl::InlinedVector<int64, 4> slice_dims;
293     for (int i = 0; i < begin.size(); ++i) {
294       // TODO(b/121179231): implement strides != 1
295       OP_REQUIRES(
296           ctx, strides[i] == 1 || strides[i] == -1,
297           errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
298       if (strides[i] > 0) {
299         slice_begin.push_back(xla::ConstantR0<int64>(ctx->builder(), begin[i]));
300         slice_dims.push_back(end[i] - begin[i]);
301       } else {
302         // Negative stride: swap begin and end, add 1 because the interval
303         // is semi-open, and mark the dimension to be reversed.
304         slice_begin.push_back(
305             xla::ConstantR0<int64>(ctx->builder(), end[i] + 1));
306         slice_dims.push_back(begin[i] - end[i]);
307         dimensions_to_reverse.push_back(i);
308       }
309     }
310 
311     if (!dimensions_to_reverse.empty()) {
312       rhs = xla::Rev(rhs, dimensions_to_reverse);
313     }
314     rhs = xla::Reshape(rhs, slice_dims);
315 
316     lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
317 
318     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
319   }
320 
321  private:
322   int32 begin_mask_, end_mask_;
323   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
324   DataType index_type_;
325   DataType dtype_;
326 };
327 
328 REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
329                     .CompileTimeConstantInput("begin")
330                     .CompileTimeConstantInput("end")
331                     .CompileTimeConstantInput("strides"),
332                 StridedSliceAssignOp);
333 
334 }  // namespace
335 }  // namespace tensorflow
336