/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
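//
// Slice extracts output[i0, ..., in] = input[begin[0] + i0, ..., begin[n] + in],
// where the output has shape `size`; a size of -1 in dimension d means "all
// elements from begin[d] to the end of dimension d". For example, with the
// Python API, tf.slice(t, [1, 0, 0], [1, 1, 3]) on a tensor of shape [3, 2, 3]
// returns a slice of shape [1, 1, 3].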

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

namespace {

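// Converts a DT_INT32 or DT_INT64 tensor into an InlinedVector<int64, 4>.
// Any other dtype is a programming error and aborts via LOG(FATAL).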
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64>()(i));
    }
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Shared code that is not dependent on the type of T.  We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
static void SharedValidation(OpKernelContext* context,
                             TensorShape* output_shape, bool* is_identity,
                             bool* slice_dim0,
                             gtl::InlinedVector<int64, 4>* begin,
                             gtl::InlinedVector<int64, 4>* size) {
  const Tensor& input = context->input(0);
  const Tensor& begin_tensor = context->input(1);
  const Tensor& size_tensor = context->input(2);

  OP_REQUIRES(
      context,
      context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
          context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input.dims() &&
          size_tensor.NumElements() == input.dims(),
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input.dims(), ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  const int input_dims = input.dims();
  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
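  // Illustrative example: for an input of shape [4, 6] with begin = [1, 0] and
  // size = [2, -1], the loop below resolves the -1 to 6 - 0 = 6, so the
  // resulting slice has shape [2, 6].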
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input.dim_size(i) - (*begin)[i];
    }
  }

  *is_identity = true;
  *slice_dim0 = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input.dim_size(i), "], but got ", b));
      OP_REQUIRES(
          context, 0 <= s && b + s <= input.dim_size(i),
          errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                  input.dim_size(i) - b, "], but ", "got ", s));
    }
    output_shape->AddDim(s);
    const bool take_all = (b == 0) && (s == input.dim_size(i));
    (*is_identity) &= take_all;
    (*slice_dim0) &= (i == 0) || take_all;
  }
}

// Code extracted from SliceOp::Compute so that MklSliceOp can reuse this
// generic code.
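// Handles the no-copy fast paths: if the slice spans the entire input, the
// input tensor is forwarded as the output; if only dimension 0 is sliced and
// the slice is suitably aligned, a buffer-sharing Slice() of the input is
// returned. Otherwise an output tensor of `output_shape` is allocated and
// `*done` is left false so the caller performs the actual copy.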
template <typename T>
static void SharedSliceCommonCases(OpKernelContext* context,
                                   TensorShape* output_shape,
                                   gtl::InlinedVector<int64, 4>* begin,
                                   gtl::InlinedVector<int64, 4>* size,
                                   Tensor** result, bool* done) {
  bool is_identity = true;
  bool slice_dim0 = true;
  *done = false;

  SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin,
                   size);
  if (!context->status().ok()) return;
  const Tensor& input = context->input(0);
  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    *done = true;
    return;
  }

  if (slice_dim0 &&
      IsDim0SliceAligned<T>(input.shape(), (*begin)[0], (*size)[0])) {
    VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
    CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
    context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
    *done = true;
    return;
  }

  OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result));
}

template <typename Device, typename T>
class SliceOp : public OpKernel {
 public:
  explicit SliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape output_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    Tensor* result = nullptr;
    bool done = false;
    SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
                              &done);
    if (!context->status().ok() || done == true) return;

    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    if (output_shape.num_elements() > 0) {
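      // Fast path for 2-D CPU slices of memcpy-able types: each output row is
      // contiguous in the input, so rows are copied with memcpy while the next
      // source and destination rows are prefetched.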
      if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
          DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        auto input = context->input(0).tensor<T, 2>();
        auto output = result->tensor<T, 2>();
        // TODO(agarwal): Consider multi-threading this loop for cases where
        // size[0] is very large.
        for (int i = 0; i < size[0]; ++i) {
          const int64 row = begin[0] + i;
          if (i + 1 < size[0]) {
            port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
            port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
          }
          memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
        }
        return;
      }
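// Dispatches to the Eigen-based implementation specialized on the number of
// input dimensions (1 through 7).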
#define HANDLE_DIM(NDIM)                            \
  if (input_dims == NDIM) {                         \
    HandleCase<NDIM>(context, begin, size, result); \
    return;                                         \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("SliceOp : Unhandled input dimensions"));
    }
  }

 private:
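  // Copies the requested region using the Eigen-based Slice functor, after
  // converting begin/size to Eigen::DSizes of rank NDIM.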
  template <int NDIM>
  void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& size, Tensor* result) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
    for (int i = 0; i < NDIM; ++i) {
      indices[i] = begin[i];
      sizes[i] = size[i];
    }

    functor::Slice<Device, T, NDIM>()(
        context->eigen_device<Device>(), result->tensor<T, NDIM>(),
        context->input(0).tensor<T, NDIM>(), indices, sizes);
  }
};

// Forward declarations of the functor specializations declared in the
// sharded source files.
namespace functor {
#define DECLARE_CPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<CPUDevice, T, NDIM>::operator()(                      \
      const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<CPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_CPU_SPEC(T, 1); \
  DECLARE_CPU_SPEC(T, 2); \
  DECLARE_CPU_SPEC(T, 3); \
  DECLARE_CPU_SPEC(T, 4); \
  DECLARE_CPU_SPEC(T, 5); \
  DECLARE_CPU_SPEC(T, 6); \
  DECLARE_CPU_SPEC(T, 7);

TF_CALL_ALL_TYPES(DECLARE_FOR_N);

#undef DECLARE_FOR_N
#undef DECLARE_CPU_SPEC
}  // namespace functor

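// Registers the CPU kernel. "begin" and "size" are pinned to host memory
// because the kernel reads them on the host to compute the output shape.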
#define REGISTER_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          SliceOp<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_GPU_SPEC(T, 1); \
  DECLARE_GPU_SPEC(T, 2); \
  DECLARE_GPU_SPEC(T, 3); \
  DECLARE_GPU_SPEC(T, 4); \
  DECLARE_GPU_SPEC(T, 5); \
  DECLARE_GPU_SPEC(T, 6); \
  DECLARE_GPU_SPEC(T, 7);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
TF_CALL_complex64(DECLARE_FOR_N);
TF_CALL_complex128(DECLARE_FOR_N);
TF_CALL_bfloat16(DECLARE_FOR_N);
TF_CALL_bool(DECLARE_FOR_N);
TF_CALL_int8(DECLARE_FOR_N);
TF_CALL_int64(DECLARE_FOR_N);
DECLARE_FOR_N(int32);

#undef DECLARE_FOR_N
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          SliceOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_bool(REGISTER_GPU);
TF_CALL_int8(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
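// Because every tensor for this registration lives in host memory, the CPU
// implementation of the op is reused directly.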
REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("size")
                            .HostMemory("output"),
                        SliceOp<CPUDevice, int32>);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// Forward declarations of the functor specializations for SYCL.
namespace functor {
#define DECLARE_SYCL_SPEC(T, NDIM)                                  \
  template <>                                                       \
  void Slice<SYCLDevice, T, NDIM>::operator()(                      \
      const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                  \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,        \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);         \
  extern template struct Slice<SYCLDevice, T, NDIM>;

#define DECLARE_FOR_N(T)   \
  DECLARE_SYCL_SPEC(T, 1); \
  DECLARE_SYCL_SPEC(T, 2); \
  DECLARE_SYCL_SPEC(T, 3); \
  DECLARE_SYCL_SPEC(T, 4); \
  DECLARE_SYCL_SPEC(T, 5); \
  DECLARE_SYCL_SPEC(T, 6); \
  DECLARE_SYCL_SPEC(T, 7);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N);
DECLARE_FOR_N(int32);
DECLARE_FOR_N(bool);

#undef DECLARE_FOR_N
#undef DECLARE_SYCL_SPEC
}  // namespace functor

#define REGISTER_SYCL(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("Slice")                        \
                              .Device(DEVICE_SYCL)             \
                              .TypeConstraint<type>("T")       \
                              .HostMemory("begin")             \
                              .HostMemory("size")              \
                              .TypeConstraint<int32>("Index"), \
                          SliceOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

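// As with the GPU build, the int32 registration below keeps all tensors in
// host memory and reuses the CPU implementation.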
REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Index")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("size")
                            .HostMemory("output"),
                        SliceOp<CPUDevice, int32>);
#undef REGISTER_SYCL

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow