/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {
namespace {

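// Fast path used by StridedSliceOp below: when T can be copied with memcpy
// and the slice is a simple, contiguous 2-D row/column range, the rows are
// copied directly with memcpy (with prefetch hints) instead of going through
// the generic strided-slice functors.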
template <typename T>
struct MemCpyFunctor {
  // Returns true if the copy was made with memcpy, false otherwise.
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      auto in = input.tensor<T, 2>();
      auto output = result->tensor<T, 2>();
      // TODO(agarwal): Consider multi-threading if size[0] is large
      for (int row_in = begin[0], row_out = 0; row_in < end[0];
           ++row_in, ++row_out) {
        if (row_in + 1 < end[0]) {
          port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0));
          port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
        }
        memcpy(&output(row_out, 0), &in(row_in, begin[1]),
               (end[1] - begin[1]) * sizeof(T));
      }
      return true;
    }
    return false;
  }
};

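// ResourceHandle tensors are not memcpy-able, so this specialization always
// reports that no fast-path copy was made.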
template <>
struct MemCpyFunctor<ResourceHandle> {
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    return false;
  }
};

}  // namespace

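// Forward kernel: takes the input tensor plus the begin/end/strides vectors
// (inputs 1-3) and the *_mask attrs, validates them with
// ValidateStridedSliceOp, then picks the cheapest applicable path: identity
// copy, contiguous dim-0 slice, 2-D memcpy, or the generic
// HandleStridedSliceCase specialization for 1 to 7 dimensions.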
template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     context->input(0).shape(), begin_mask, end_mask,
                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
                     &processing_shape, &final_shape, &is_identity,
                     &is_simple_slice, &slice_dim0, &begin, &end, &strides));
    const Tensor& input = context->input(0);

    // Optimization #1, slice is a no-op plus reshape
    if (is_identity) {
      VLOG(1) << "Strided slice identity ";
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(input, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

    // Optimization #2, slice is memory contiguous (only occurs in dim 0)
    if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
      OP_REQUIRES(context, input.dims() >= 1,
                  errors::InvalidArgument(
                      "Input must have rank at least 1, got: ", input.dims()));
      // Otherwise, is_identity should be true.
      VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
      // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).
      Tensor slice = input.Slice(std::min(begin[0], end[0]), end[0]);
      Tensor tmp;
      OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape),
                  errors::Internal("Copy failed"));
      context->set_output(0, tmp);
      return;
    }

    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, final_shape, &result));
    const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();

    if (processing_shape.num_elements() > 0) {
      // Optimization #3, slice has stride 1 in all dimensions
      // Optimization #3A, slice has only two dimensions
      // TODO(aselle): Here we are restricting to processing_shape and
      // final_shape being 2D. This isn't strictly necessary, but I don't
      // want to blow up code gen size, because to shape<> you need static
      // NDIM and T
      if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
          input_dims == 2 && processing_shape.dims() == 2 &&
          final_shape.dims() == 2 && new_axis_mask == 0) {
        MemCpyFunctor<T> functor;
        if (functor.Copy(input, begin, end, result)) {
          return;
        }
      }

#define HANDLE_DIM(NDIM)                                                       \
  if (processing_dims == NDIM) {                                               \
    HandleStridedSliceCase<Device, T, NDIM>(context, begin, end, strides,      \
                                            processing_shape, is_simple_slice, \
                                            result);                           \
    return;                                                                    \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("Unhandled input dimensions ", input_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

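// Gradient kernel: takes the original input shape (input 0, a 1-D int32 or
// int64 vector), the same begin/end/strides inputs and masks as the forward
// op, and dy (input 4). It re-runs slice validation against the stored shape,
// checks that dy matches the sliced shape, and writes dy back into a tensor of
// the original input shape via HandleStridedSliceGradCase.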
template <typename Device, typename T>
class StridedSliceGradOp : public OpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    TensorShape input_shape;
    const Tensor& input_shape_tensor = context->input(0);
    OP_REQUIRES(
        context, input_shape_tensor.dims() == 1,
        errors::InvalidArgument("shape must be 1-D, got shape.shape = ",
                                input_shape_tensor.shape().DebugString()));
    if (input_shape_tensor.dtype() == DT_INT32) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int32>(),
                                               &input_shape));
    } else if (input_shape_tensor.dtype() == DT_INT64) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int64>(),
                                               &input_shape));
    } else {
      LOG(FATAL) << "shape must have type int32 or int64.";
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    TensorShape dy_shape = context->input(4).shape();
    OP_REQUIRES(
        context, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    if (!context->status().ok()) return;

    // const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &result));

    if (processing_shape.dims() == 0) {
      auto in = context->input(4);
      OP_REQUIRES(context, result->CopyFrom(in, processing_shape),
                  errors::Internal("Copy failed"));
      return;
    }

#define HANDLE_DIM(NDIM)                                                      \
  if (processing_dims == NDIM) {                                              \
    HandleStridedSliceGradCase<Device, T, NDIM>(context, begin, end, strides, \
                                                processing_shape,             \
                                                is_simple_slice, result);     \
    return;                                                                   \
  }

    HANDLE_DIM(1);
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    HANDLE_DIM(5);
    HANDLE_DIM(6);
    HANDLE_DIM(7);

#undef HANDLE_DIM
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

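// Assignment kernel (StridedSliceAssign / ResourceStridedSliceAssign): writes
// the r-value (input 4) into a strided slice of the l-value, which is either a
// resource variable (looked up and locked through its mutex) or a ref input
// that is forwarded to the output. The r-value shape must currently match the
// sliced l-value shape exactly; broadcasting is not yet implemented.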
template <typename Device, typename T>
class StridedSliceAssignOp : public OpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    Tensor* old_lhs = nullptr;
    Tensor tmp;
    if (context->input_dtype(0) == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(context,
                     LookupResource(context, HandleFromInput(context, 0), &v));
      core::ScopedUnref scoped_unref(v);
      OP_REQUIRES_OK(context,
                     EnsureSparseVariableAccess<Device, T>(context, v));
      mutex_lock ml(*v->mu());
      old_lhs = v->tensor();
      OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
                  errors::InvalidArgument(
                      "l-value dtype ", DataTypeString(old_lhs->dtype()),
                      " does not match r-value dtype ",
                      DataTypeString(DataTypeToEnum<T>::value)));
    } else {
      context->forward_ref_input_to_ref_output(0, 0);
      tmp = context->mutable_input(0, true);
      old_lhs = &tmp;
    }

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
                     new_axis_mask, shrink_axis_mask, &processing_shape,
                     &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
                     &begin, &end, &strides));

    if (processing_shape.num_elements()) {
      const Tensor& input = context->input(4);
      TensorShape input_shape = input.shape();
      TensorShape original_shape = old_lhs->shape();
      // TODO(aselle): This check is too strong, we only should need
      // input_shape to be broadcastable to final_shape
      OP_REQUIRES(
          context, final_shape == input_shape,
          errors::Unimplemented(
              "sliced l-value shape ", final_shape.DebugString(),
              " does not match r-value shape ", input_shape.DebugString(),
              ". Automatic broadcasting not ", "yet implemented."));
      const int processing_dims = processing_shape.dims();

      // 0-dimensional case implies the left and right are exactly the same
      // scalar shape

// Handle general dimensions
#define HANDLE_DIM(NDIM)                                                       \
  if (processing_dims == NDIM) {                                               \
    HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end,       \
                                                    strides, processing_shape, \
                                                    is_simple_slice, old_lhs); \
    return;                                                                    \
  }
      HANDLE_DIM(0);
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

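// CPU registrations: StridedSlice, StridedSliceGrad, StridedSliceAssign, and
// ResourceStridedSliceAssign for every type supported on CPU. The
// begin/end/strides (and shape/ref, where present) inputs are declared as host
// memory since the kernels read them on the host to compute the slice bounds.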
#define REGISTER_STRIDED_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<CPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<CPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

#if GOOGLE_CUDA

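// GPU registrations mirror the CPU set. begin/end/strides stay in host memory
// because the slice bounds are computed on the host before any device work is
// launched.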
#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<GPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<GPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bool(REGISTER_GPU);
TF_CALL_int8(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

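// SYCL registrations follow the same pattern as the GPU ones; int32, which is
// kept entirely in host memory, is routed to the CPU kernel implementations.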
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                    \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceOp<SYCLDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("shape")                \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceGradOp<SYCLDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")              \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")      \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("ref")                  \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow