/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result);
};
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
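// For example, a per-dimension instantiation unit could look like the
// following (an illustrative sketch; the exact .cc file layout used by the
// build is an assumption here):
//
//   // strided_slice_op_inst_3.cc (hypothetical)
//   #define STRIDED_SLICE_INSTANTIATE_DIM 3
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//
// Each such unit then emits only the NDIM == 3 instantiations, so no single
// translation unit has to compile every dimensionality.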
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  // Proxy is a bit-compatible stand-in for T (see register_types_traits.h);
  // the functors below operate on bit-cast views of the tensors.
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    // Unit-stride case: dispatch to the plain Slice functor with
    // sizes = end - begin.
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    // General case: pass begin/end/strides through to the StridedSlice
    // functor.
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  // input(4) carries the incoming gradient values; the functor scatters them
  // back through the strided slice into `result`.
  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
    const gtl::ArraySlice<int64>& end, const gtl::ArraySlice<int64>& strides,
    const TensorShape& processing_shape, bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }
  // input(4) holds the values being assigned into the strided view of
  // `result`.
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

// Scalar (0-D) specialization: the assignment copies a single element, which
// is handled as a one-element 1-D tensor.
template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result) {
    gtl::InlinedVector<int64, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles the instantiations externally. It is important that this is done
// before the HandleXXCase templates are instantiated, to avoid duplicate
// specialization errors.
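//
// For illustration, the macros below rely on the standard C++
// "extern template" mechanism (a generic sketch, not TensorFlow-specific
// code):
//
//   template <typename T>
//   struct Functor {
//     void operator()(const T& t) { /* ... */ }
//   };
//   // Explicit instantiation declaration: suppresses instantiation of
//   // Functor<float> in this translation unit; the definition must be
//   // compiled in another one (here, the nvcc-compiled .cu file).
//   extern template struct Functor<float>;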

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                   \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(               \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;                \
  template <>                                                      \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;     \
  template <>                                                      \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;   \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                   \
  namespace functor {                                            \
  template <>                                                    \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(       \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,  \
      typename TTypes<T, 1>::ConstTensor input);                 \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>; \
  }  // namespace functor

// Dimension 0 only instantiates some of the functors, so we only need the
// prevention defined by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)              \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(               \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin, \
      const gtl::ArraySlice<int64>& end,                              \
      const gtl::ArraySlice<int64>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,      \
      Tensor* result);                                                \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(           \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin, \
      const gtl::ArraySlice<int64>& end,                              \
      const gtl::ArraySlice<int64>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,      \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif
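
// As an illustration (a sketch assuming STRIDED_SLICE_INSTANTIATE_DIM == 3),
// INSTANTIATE(CPUDevice, float, 3) expands roughly to:
//
//   template class HandleStridedSliceAssignCase<CPUDevice, float, 3>;
//   template void HandleStridedSliceCase<CPUDevice, float, 3>(
//       OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
//       const gtl::ArraySlice<int64>& end,
//       const gtl::ArraySlice<int64>& strides,
//       const TensorShape& processing_shape, bool is_simple_slice,
//       Tensor* result);
//   template void HandleStridedSliceGradCase<CPUDevice, float, 3>(
//       /* same parameter list as above */);
//
// i.e. explicit instantiation definitions of the handlers declared at the
// top of this header, for one device/type/rank combination.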

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_complex64(PREVENT_FOR_N_GPU);
TF_CALL_complex128(PREVENT_FOR_N_GPU);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
TF_CALL_complex128(DECLARE_FOR_N_GPU);
TF_CALL_bool(DECLARE_FOR_N_GPU);
TF_CALL_int8(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
DECLARE_FOR_N_GPU(int64);
#endif  // END GOOGLE_CUDA

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);

#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_SYCL(T) \
  INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
DECLARE_FOR_N_SYCL(int32);
DECLARE_FOR_N_SYCL(int64);

#undef DECLARE_FOR_N_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_