/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.
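// DynamicStitch interleaves the values of the `data` tensors into a single
// tensor, placing data[i][j, ...] at merged[indices[i][j], ...].  Illustrative
// example:
//   indices[0] = [1, 3], indices[1] = [0, 2]
//   data[0]    = [10, 30], data[1]  = [0, 20]
//   merged     = [0, 10, 20, 30]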
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/gpu_device_array.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA

template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
  explicit DynamicStitchOpImplBase(OpKernelConstruction* c,
                                   const string& op_name)
      : OpKernel(c) {
    // Compute expected input signature
    const DataType dt = DataTypeToEnum<T>::v();
    const int n = c->num_inputs() / 2;
    DataTypeVector expected;
    for (int i = 0; i < n; i++) {
      expected.push_back(DT_INT32);
    }
    for (int i = 0; i < n; i++) {
      expected.push_back(dt);
    }
    OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
    OP_REQUIRES(c, c->num_inputs() > 0,
                errors::InvalidArgument(op_name + ": Must have some inputs"));
    OP_REQUIRES(c, c->num_inputs() % 2 == 0,
                errors::InvalidArgument(
                    op_name + ": Must have even number of arguments"));
  }

 protected:
  // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
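  // For example (illustrative): data0 of shape [4, 2, 3] with indices0 of
  // shape [4] has extra shape [2, 3], which data1/indices1 must match.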
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
    if (extra0 != extra1) return false;
    for (int i = 0; i < extra0; i++) {
      if (data0.dim_size(indices0.dims() + i) !=
          data1.dim_size(indices1.dims() + i)) {
        return false;
      }
    }
    return true;
  }

  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs, int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));

    int32 max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }

    *first_dim_size = max_index + 1;
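    // first_dim_size is one more than the largest index seen; output rows that
    // no index refers to are left uninitialized (see the TODO in Compute).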

    // Validate that data[i].shape = indices[i].shape + constant
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
    const Tensor& data0 = (*data_inputs)[0];
    const Tensor& indices0 = (*indices_inputs)[0];
    for (int input_num = 0; input_num < indices_inputs->size(); input_num++) {
      const Tensor& indices = (*indices_inputs)[input_num];
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }

    // Allocate result tensor of shape
    //   [*first_dim_size] + data.shape[indices.dims:]
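    // For example (illustrative): indices0 of shape [4], data0 of shape
    // [4, 2, 3], and max_index == 9 yield a result of shape [10, 2, 3].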
    TensorShape result_shape;
    result_shape.AddDim(*first_dim_size);
    for (int d = indices0.dims(); d < data0.dims(); d++) {
      result_shape.AddDim(data0.dim_size(d));
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr));
  }
};

#if GOOGLE_CUDA
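// DynamicStitchGPUImpl is declared here; its definition is assumed to live in
// the CUDA compilation unit for this kernel. For each output row i it is
// expected to copy the slice pointed to by input_ptrs[input_indices[i]] into
// output + i * slice_size, skipping rows whose index is -1.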
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const GpuDeviceArrayStruct<int>& input_indices,
                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);
#define REGISTER_GPU(T)                                           \
  extern template void DynamicStitchGPUImpl(                      \
      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
      const int32 first_dim_size,                                 \
      const GpuDeviceArrayStruct<int32>& input_indices,           \
      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU);
#undef REGISTER_GPU

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error is
      // passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      // Because indices may collide (the same output row can be written by
      // more than one input), we resolve collisions on the CPU before sending
      // the data to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 marks an output row with no incoming
      // index).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Index into data_flat.
      int32 idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // values stored in indices_flat.
      int32 base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat is indexed by output row; its values are positions in
          // the flattened list of input slices (and hence into data_flat),
          // telling where that row's data lives.
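          // For example (illustrative): with indices_inputs = {[1, 3], [0, 2]}
          // this loop produces indices_flat = [2, 0, 3, 1], while data_flat
          // holds one pointer per input slice in input order.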
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
                              indices_flat.data(), data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA

template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error is
      // passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      auto merged_flat = merged->flat_outer_dims<T>();
      const int slice_size = merged_flat.dimension(1);
      const size_t slice_bytes = slice_size * sizeof(T);
      auto OnInputNumber = [&](int input_num) {
        const Tensor& indices = indices_inputs[input_num];
        auto indices_vec = indices.flat<int32>();
        const Tensor& data = data_inputs[input_num];
        auto data_flat =
            data.shaped<T, 2>({indices_vec.dimension(0), slice_size});

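        // Fast path: element types that are plain old data can be copied with
        // memcpy; other types (e.g. strings) use Eigen slice assignment below.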
        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          T* merged_base = merged_flat.data();
          const T* data_base = data_flat.data();
          for (int i = 0; i < indices_vec.size(); i++) {
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            memcpy(merged_base + index * slice_size, data_base + i * slice_size,
                   slice_bytes);
          }
        } else {
          Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
          for (int i = 0; i < indices_vec.size(); i++) {
            // Copy slice data[i] to merged[indices[i]]
            Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
            merged_flat.slice(merged_indices, sizes) =
                data_flat.slice(data_indices, sizes);
          }
        }
      };
      if (Parallel) {
        auto thread_pool =
            c->device()->tensorflow_cpu_worker_threads()->workers;
        size_t total_indices_size = 0;
        for (int input_num = 0; input_num < indices_inputs.size();
             ++input_num) {
          total_indices_size += indices_inputs[input_num].NumElements();
        }
        const double avg_indices_size =
            static_cast<double>(total_indices_size) / indices_inputs.size();
        auto bytes_processed = slice_bytes * avg_indices_size;
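        // bytes_processed is the average number of bytes copied per input and
        // is used by ParallelFor as the per-unit cost estimate when sharding
        // the inputs across worker threads.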
        auto LoopBody = [&](int first, int last) {
          for (int input_num = first; input_num < last; ++input_num) {
            OnInputNumber(input_num);
          }
        };
        thread_pool->ParallelFor(indices_inputs.size(), bytes_processed,
                                 LoopBody);
      } else {
        for (int input_num = 0; input_num < indices_inputs.size();
             input_num++) {
          OnInputNumber(input_num);
        }
      }
    }
  }
};

// Using inheritance rather than a typedef so that these classes might have more
// functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

#define REGISTER_DYNAMIC_STITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
TF_CALL_variant(REGISTER_DYNAMIC_STITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH

#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type)                \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpGPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_DYNAMIC_STITCH_SYCL(type)               \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
#undef REGISTER_DYNAMIC_STITCH_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow