/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/gpu_device_array.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA

template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
  explicit DynamicStitchOpImplBase(OpKernelConstruction* c,
                                   const string& op_name)
      : OpKernel(c) {
    // Compute expected input signature
    const DataType dt = DataTypeToEnum<T>::v();
    const int n = c->num_inputs() / 2;
    DataTypeVector expected;
    for (int i = 0; i < n; i++) {
      expected.push_back(DT_INT32);
    }
    for (int i = 0; i < n; i++) {
      expected.push_back(dt);
    }
    OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
    OP_REQUIRES(c, c->num_inputs() > 0,
                errors::InvalidArgument(op_name + ": Must have some inputs"));
    OP_REQUIRES(c, c->num_inputs() % 2 == 0,
                errors::InvalidArgument(
                    op_name + ": Must have even number of arguments"));
  }

 protected:
  // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
    if (extra0 != extra1) return false;
    for (int i = 0; i < extra0; i++) {
      if (data0.dim_size(indices0.dims() + i) !=
          data1.dim_size(indices1.dims() + i)) {
        return false;
      }
    }
    return true;
  }

  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs, int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));

    int32 max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }
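
    // Worked example (editorial note; illustrative values): indices inputs
    // {[0, 2], [1]} have a maximum index of 2, so *first_dim_size (set below)
    // is 3, and *data_elements_size (accumulated above) is 3, the total
    // number of indices across all inputs.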

    *first_dim_size = max_index + 1;

    // Validate that data[i].shape = indices[i].shape + constant
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
    const Tensor& data0 = (*data_inputs)[0];
    const Tensor& indices0 = (*indices_inputs)[0];
    for (int input_num = 0; input_num < indices_inputs->size(); input_num++) {
      const Tensor& indices = (*indices_inputs)[input_num];
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }

    // Allocate result tensor of shape
    //   [*first_dim_size] + data.shape[indices.dims:]
    TensorShape result_shape;
    result_shape.AddDim(*first_dim_size);
    for (int d = indices0.dims(); d < data0.dims(); d++) {
      result_shape.AddDim(data0.dim_size(d));
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr));
  }
};

#if GOOGLE_CUDA

template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const GpuDeviceArrayStruct<int>& input_indices,
                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);
#define REGISTER_GPU(T)                                           \
  extern template void DynamicStitchGPUImpl(                      \
      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
      const int32 first_dim_size,                                 \
      const GpuDeviceArrayStruct<int32>& input_indices,           \
      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU);
#undef REGISTER_GPU

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
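
    // Worked example (editorial note; illustrative values): with
    //   indices = {[0, 2], [1]} and data = {[[1, 2], [3, 4]], [[5, 6]]},
    // first_dim_size is 3 and slice_size is 2.  The loop below builds
    //   indices_flat = {0, 2, 1}   (output row r is fed by data slice
    //                               number indices_flat[r])
    //   data_flat    = {row 0 of data[0], row 1 of data[0], row 0 of data[1]}
    // so the GPU kernel writes merged = [[1, 2], [5, 6], [3, 4]].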
    if (first_dim_size > 0) {
      // Because of the collision requirements, we have to resolve collisions
      // here before sending the data to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 represents a missing index).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Index into data_flat.
      int32 idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // indices_flat values.
      int32 base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat's indices represent the indices of the output.
          // indices_flat's values represent the indices into the input data
          // where the slice is located.
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size,
                              first_dim_size, indices_flat.data(),
                              data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA

template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
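
    // Shape note (editorial): merged is processed below as a 2-D view with
    // first_dim_size rows of slice_size contiguous elements each.  For
    // example (illustrative values), a merged tensor of shape [3, 2, 2] is
    // viewed as a [3, 4] matrix, so slice_size is 4 and slice_bytes is
    // 4 * sizeof(T).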
    if (first_dim_size > 0) {
      auto merged_flat = merged->flat_outer_dims<T>();
      const int slice_size = merged_flat.dimension(1);
      const size_t slice_bytes = slice_size * sizeof(T);
      auto OnInputNumber = [&](int input_num) {
        const Tensor& indices = indices_inputs[input_num];
        auto indices_vec = indices.flat<int32>();
        const Tensor& data = data_inputs[input_num];
        auto data_flat =
            data.shaped<T, 2>({indices_vec.dimension(0), slice_size});

        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          T* merged_base = merged_flat.data();
          const T* data_base = data_flat.data();
          for (int i = 0; i < indices_vec.size(); i++) {
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            memcpy(merged_base + index * slice_size,
                   data_base + i * slice_size, slice_bytes);
          }
        } else {
          Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
          for (int i = 0; i < indices_vec.size(); i++) {
            // Copy slice data[i] to merged[indices[i]]
            Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
            merged_flat.slice(merged_indices, sizes) =
                data_flat.slice(data_indices, sizes);
          }
        }
      };
      if (Parallel) {
        auto thread_pool =
            c->device()->tensorflow_cpu_worker_threads()->workers;
        size_t total_indices_size = 0;
        for (int input_num = 0; input_num < indices_inputs.size();
             ++input_num) {
          total_indices_size += indices_inputs[input_num].NumElements();
        }
        const double avg_indices_size =
            static_cast<double>(total_indices_size) / indices_inputs.size();
        auto bytes_processed = slice_bytes * avg_indices_size;
        auto LoopBody = [&](int first, int last) {
          for (int input_num = first; input_num < last; ++input_num) {
            OnInputNumber(input_num);
          }
        };
        thread_pool->ParallelFor(indices_inputs.size(), bytes_processed,
                                 LoopBody);
      } else {
        for (int input_num = 0; input_num < indices_inputs.size();
             input_num++) {
          OnInputNumber(input_num);
        }
      }
    }
  }
};
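
// Cost-model note (editorial; illustrative values): in the parallel path
// above, ParallelFor receives bytes_processed = slice_bytes * avg_indices_size
// as a rough per-input cost estimate.  For example, two index tensors with 4
// and 6 elements and slice_bytes == 32 give avg_indices_size == 5 and an
// estimate of 160 bytes per input, which the thread pool uses when deciding
// how to shard the inputs across worker threads.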

// Using inheritance rather than a typedef so that these classes might have
// more functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

#define REGISTER_DYNAMIC_STITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
TF_CALL_variant(REGISTER_DYNAMIC_STITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH

#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type)                \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpGPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_DYNAMIC_STITCH_SYCL(type)               \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
#undef REGISTER_DYNAMIC_STITCH_SYCL
#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow