/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result);
};
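
// These handlers are templated on the processing rank NDIM, so a caller is
// expected to dispatch on the runtime rank of processing_shape. A minimal
// sketch of such a dispatch (illustrative only; the actual kernels in
// strided_slice_op.cc use their own macro-based dispatch):
//
//   switch (processing_shape.dims()) {
//     case 1:
//       HandleStridedSliceCase<Device, T, 1>(context, begin, end, strides,
//                                            processing_shape,
//                                            is_simple_slice, result);
//       break;
//     case 2:
//       HandleStridedSliceCase<Device, T, 2>(context, begin, end, strides,
//                                            processing_shape,
//                                            is_simple_slice, result);
//       break;
//     // ... and so on, up to the highest instantiated NDIM.
//   }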
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
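// For example, a per-dimension instantiation unit could look like the
// following sketch (the file name is illustrative; the build splits the
// instantiations across several such translation units):
//
//   // strided_slice_op_inst_3.cc (hypothetical)
//   #define STRIDED_SLICE_INSTANTIATE_DIM 3
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//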
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
    const gtl::ArraySlice<int64>& end, const gtl::ArraySlice<int64>& strides,
    const TensorShape& processing_shape, bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result) {
    gtl::InlinedVector<int64, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles the instantiations externally. It is important that this is done
// before the HandleXXCase's are instantiated to avoid duplicate
// specialization errors.
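//
// As a rough illustration of the mechanism used by the macros below (the
// functor name here is a placeholder, not a real one): declaring the
// specialization and pairing it with an extern template declaration tells
// this translation unit that the definition lives elsewhere (in
// strided_slice_op_gpu.cu), so the compiler does not instantiate it here:
//
//   namespace functor {
//   template <>
//   void SomeFunctor<GPUDevice, float, 2>::operator()(/* args */);
//   extern template struct SomeFunctor<GPUDevice, float, 2>;
//   }  // namespace functor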

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                      \
  namespace functor {                                                 \
  template <>                                                         \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(                  \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output,    \
      typename TTypes<T, NDIM>::ConstTensor input,                    \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,            \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,             \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);         \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;            \
  template <>                                                         \
  void Slice<GPUDevice, T, NDIM>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output,    \
      typename TTypes<T, NDIM>::ConstTensor input,                    \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);           \
  extern template struct Slice<GPUDevice, T, NDIM>;                   \
  template <>                                                         \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(              \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output,    \
      typename TTypes<T, NDIM>::ConstTensor input,                    \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,            \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,             \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);         \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;        \
  template <>                                                         \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(            \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output,    \
      typename TTypes<T, NDIM>::ConstTensor input,                    \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,            \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,             \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);         \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;      \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                    \
  namespace functor {                                             \
  template <>                                                     \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(        \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,   \
      typename TTypes<T, 1>::ConstTensor input);                  \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>;  \
  }  // namespace functor

// Dimension 0 only instantiates some functors, so we only need to
// prevent the ones declared by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)                \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(                 \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin,   \
      const gtl::ArraySlice<int64>& end,                                \
      const gtl::ArraySlice<int64>& strides,                            \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);                                                  \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(             \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin,   \
      const gtl::ArraySlice<int64>& end,                                \
      const gtl::ArraySlice<int64>& strides,                            \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_complex64(PREVENT_FOR_N_GPU);
TF_CALL_complex128(PREVENT_FOR_N_GPU);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
TF_CALL_complex128(DECLARE_FOR_N_GPU);
TF_CALL_bool(DECLARE_FOR_N_GPU);
TF_CALL_int8(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
DECLARE_FOR_N_GPU(int64);
#endif  // END GOOGLE_CUDA

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);

#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_SYCL(T) \
  INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
DECLARE_FOR_N_SYCL(int32);
DECLARE_FOR_N_SYCL(int64);

#undef DECLARE_FOR_N_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_