/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>
#include <limits>
#include <numeric>
#include <tuple>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tlen>
class SplitVOpBase : public OpKernel {
 public:
  explicit SplitVOpBase(OpKernelConstruction* c) : OpKernel(c) {}

  void ComputeEasyCases(OpKernelContext* context, bool* done,
                        std::vector<Tlen>* split_sizes_vec) {
    const int32_t num_split = context->num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const Tensor& split_tensor = context->input(1);
    const Tensor& split_dim_tensor = context->input(2);

    OP_REQUIRES(context, split_dim_tensor.NumElements() == 1,
                errors::InvalidArgument("split_dim_tensor must have "
                                        "exactly one element."));

    const int32_t split_dim_orig = split_dim_tensor.flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    OP_REQUIRES(
        context,
        split_tensor.dims() == 1 && split_tensor.NumElements() == num_split,
        errors::InvalidArgument("size_splits must be a 1-D tensor with one "
                                "element per output, but got ",
                                split_tensor.dims(), "-D and ",
                                split_tensor.NumElements(), " elements"));

    auto split_sizes_d = split_tensor.vec<Tlen>();

    split_sizes_vec->resize(split_sizes_d.size());

    std::copy(split_sizes_d.data(), split_sizes_d.data() + split_sizes_d.size(),
              split_sizes_vec->begin());

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input.dims(),
        errors::InvalidArgument("-input rank (-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    Tlen input_size_split_dim = input_shape.dim_size(split_dim);

    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      context->set_output(0, context->input(0));
      OP_REQUIRES(
          context, (*split_sizes_vec)[0] == input_size_split_dim,
          errors::InvalidArgument("If there is only one output, it must have "
                                  "the same size as the input. Input size: ",
                                  input_size_split_dim,
                                  " output size: ", (*split_sizes_vec)[0]));
      *done = true;
      return;
    }

    // Determine the sizes of the outputs, resolving an optional -1 entry.
    int neg_one_dim = -1;
    Tlen determined_size = 0;
    for (int d = 0; d < split_sizes_vec->size(); ++d) {
      Tlen size = (*split_sizes_vec)[d];

      if (size == -1) {
        OP_REQUIRES(context, neg_one_dim == -1,
                    errors::InvalidArgument("There can only be one -1 in the "
                                            "input."));
        neg_one_dim = d;
      } else {
        determined_size += size;
      }
    }

    OP_REQUIRES(
        context,
        (neg_one_dim == -1 && determined_size == input_size_split_dim) ||
            (neg_one_dim >= 0 && determined_size <= input_size_split_dim),
        errors::InvalidArgument("Determined shape must either match "
                                "input shape along split_dim exactly if "
                                "fully specified, or be less than the size of "
                                "the input along split_dim if not fully "
                                "specified. Got: ",
                                determined_size));

    if (neg_one_dim >= 0) {
      (*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size;
    }

    // Special case 2: split along the first dimension. The requirement is
    // that either we are splitting the outer dimension into two or more parts
    // such that every outer subpart is aligned, or that the split sizes
    // guarantee that the outputs are always aligned. In these cases, we can
    // share the underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned, the
    // resulting tensors must also be aligned. It's conservative because if
    // the immediate consumers of the resulting tensors are not using Eigen
    // for computation, it's perfectly fine to avoid the copying.
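    // Illustrative example (added note, not part of the original comment):
    // splitting a float tensor of shape [6, 1024] along dim 0 with
    // size_splits = {2, 4} can simply return two Slices of the input
    // covering rows [0, 2) and [2, 6). As long as each slice's starting
    // offset satisfies Eigen's alignment requirement (checked by
    // IsDim0SliceAligned below), no data is copied; otherwise we fall
    // through to the general copying path.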
    if (SplitHasAlignedOutputsInFirstDimension(
            input_shape, split_dim, absl::MakeConstSpan(*split_sizes_vec))) {
      Tlen start = 0;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i,
                            input.Slice(start, start + (*split_sizes_vec)[i]));
        start += (*split_sizes_vec)[i];
      }
      *done = true;
      return;
    }
  }

  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, const int32_t split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    int32_t prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= input_shape.dim_size(i);
    }

    // Caller must ensure that dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }

 private:
  // Determines whether the given split configuration can be done using slicing
  // on the first dimension of the tensor. The requirement is that each result
  // tensor from the slice is correctly aligned within the input tensor.
  static bool SplitHasAlignedOutputsInFirstDimension(
      const TensorShape& input_shape, int32_t split_dim,
      absl::Span<const Tlen> split_sizes) {
    if (split_dim != 0) {
      return false;
    }
    Tlen start = 0;
    for (const Tlen split_size : split_sizes) {
      if (!IsDim0SliceAligned<T>(input_shape, start, start + split_size)) {
        return false;
      }
      start += split_size;
    }
    return true;
  }
};

template <typename T, typename Tlen, typename InputReshapedType, int NDims>
class SplitVOpCPUImpl {
 public:
  void ParallelSplitByInputData(OpKernelContext* context,
                                const InputReshapedType& input_reshaped,
                                const TensorShape& input_shape,
                                const std::vector<Tlen>& split_sizes_vec,
                                const int32_t split_dim) const {
    const T* p_data = input_reshaped.data();
    const uint32 elem_pkg = input_reshaped.dimensions().rank() == 3
                                ? input_reshaped.dimension(2)
                                : 1;
    const uint32 line_elem_num =
        (input_reshaped.dimensions().rank() >= 2 ? input_reshaped.dimension(1)
                                                 : 1) *
        elem_pkg;
    const uint32 line_num = input_reshaped.dimension(0);

    // Prepare the output matrix.
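    // Note (added for clarity): the loop below allocates one output tensor
    // per split entry and records its base pointer. In the usual 3-D reshape
    // {prefix, split_dim, suffix}, each of the line_num rows along dim 0
    // holds one contiguous chunk of split_sizes_vec[j] * elem_pkg elements
    // per output j, laid out back to back, so the sharded lambda that follows
    // can walk each row once and scatter the chunks into their outputs.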
    std::vector<T*> outputs(split_sizes_vec.size());
    for (uint64 i = 0; i < split_sizes_vec.size(); ++i) {
      TensorShape output_shape(input_shape);
      output_shape.set_dim(split_dim, split_sizes_vec[i]);
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      outputs[i] = static_cast<T*>(&result->flat<T>()(0));
    }

    auto sub_split_func = [&split_sizes_vec, &p_data, elem_pkg, &outputs,
                           line_elem_num](int32_t start_part,
                                          int32_t end_part) {
      int start = start_part * line_elem_num;
      int end = end_part * line_elem_num;
      uint32 times = 0;
      for (int32_t i = start; i < end;) {
        for (uint32 j = 0; j < split_sizes_vec.size(); ++j) {
          const auto copy_elem_num = split_sizes_vec[j] * elem_pkg;
          std::copy_n(p_data + i, copy_elem_num,
                      &(outputs[j][(start_part + times) * copy_elem_num]));
          i += copy_elem_num;
        }
        ++times;
      }
    };

    uint32 part_size =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    Shard(part_size,
          context->device()->tensorflow_cpu_worker_threads()->workers,
          line_num, line_num, sub_split_func);
  }

  template <typename MakeSizesType, typename ReshapeResultType>
  void operator()(OpKernelContext* context,
                  const InputReshapedType& input_reshaped,
                  const std::vector<int64>& split_start_points,
                  const TensorShape& input_shape, int32_t split_dim,
                  Eigen::DenseIndex prefix_dim_size,
                  Eigen::DenseIndex split_dim_size,
                  Eigen::DenseIndex suffix_dim_size,
                  std::vector<Tlen>& split_sizes_vec,
                  const MakeSizesType& make_sizes,
                  const ReshapeResultType& reshape_result) const {
    Eigen::DSizes<Eigen::DenseIndex, NDims> indices;
    for (int i = 0; i < NDims; ++i) {
      indices[i] = 0;
    }
    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
    const int num_split = split_start_points.size();
    const bool use_parallelism_between_outputs =
        (num_split >= kMinimumSplitNum &&
         input_element_count >= std::min(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);

    auto range_output_func = [&indices, context, &input_shape, split_dim,
                              &split_sizes_vec, &split_start_points,
                              use_parallelism_between_outputs, &input_reshaped,
                              &make_sizes,
                              &reshape_result](int64_t start, int64_t limit) {
      for (int64_t i = start; i < limit; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        const auto sizes = make_sizes(split_sizes_vec[i]);

        if (sizes.TotalSize() > 0) {
          auto result_shaped = reshape_result(result, split_sizes_vec[i]);

          auto current_indices = indices;
          current_indices[NDims - 2] = split_start_points[i];
          if (use_parallelism_between_outputs) {
            // Use sequential implementation for single output.
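            // (Added note) When use_parallelism_between_outputs is true, the
            // surrounding Shard call already distributes whole outputs across
            // threads, so a plain Eigen slice assignment here avoids nested
            // parallelism; the functor::Split branch below is only reached
            // from the sequential outer loop and may parallelize internally.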
            result_shaped = input_reshaped.slice(current_indices, sizes);
          } else {
            // This implementation may be parallel internally.
            functor::Split<CPUDevice, T, NDims>()(
                context->eigen_device<CPUDevice>(), result_shaped,
                input_reshaped, current_indices, sizes);
          }
        }
      }
    };

    // Use ParallelSplitByInputData when all of the following hold:
    // 1. Parallel performance is not as good as serial when the amount of
    //    data is too small (< kMinimumInputSize);
    // 2. There is sufficient data on the 0th dimension to ensure parallelism;
    // 3. This method only supports a non-zero split_dim.
    if ((input_element_count >= kMinimumInputSize) &&
        input_reshaped.dimension(0) > kMinimumDim0Size && split_dim) {
      // Each thread processes the same amount of data, and then copies data
      // to all output tensors.
      ParallelSplitByInputData(context, input_reshaped, input_shape,
                               split_sizes_vec, split_dim);
    } else if (use_parallelism_between_outputs) {
      // Each thread maps to one output tensor; it traverses all of the input
      // data and copies the relevant portion into its output. Run the outputs
      // in parallel, disabling parallelism inside the functor.
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in functor.
      range_output_func(0, num_split);
    }
  }
  static constexpr uint64 kMinimumInputSize = 4096 * 512;
  static constexpr uint64 kMinimumDim0Size = 8;
  static constexpr uint64 kMinimumSplitNum = 4;
};

template <typename T, typename Tlen>
class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<CPUDevice, T, Tlen> Base;
  explicit SplitVOpCPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32_t num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t split_dim_orig = context->input(2).flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    std::vector<int64> split_start_points(num_split);
    for (int i = 0; i < num_split; ++i) {
      if (i == 0) {
        split_start_points[i] = 0;
      } else {
        split_start_points[i] =
            split_start_points[i - 1] + split_sizes_vec[i - 1];
      }
    }

    if (prefix_dim_size == 1) {
      auto input_reshaped =
          input.shaped<T, 2>({split_dim_size, suffix_dim_size});
      auto make_sizes = [&](Eigen::DenseIndex split_size) {
        return Eigen::DSizes<Eigen::DenseIndex, 2>{split_size, suffix_dim_size};
      };
      auto reshape_result = [&](Tensor* result, Tlen split_size) {
        return result->shaped<T, 2>({split_size, suffix_dim_size});
      };
      SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 2>{}(
          context, input_reshaped, split_start_points, input_shape, split_dim,
          prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec,
          make_sizes, reshape_result);
    } else {
      auto input_reshaped = input.shaped<T, 3>(
          {prefix_dim_size, split_dim_size, suffix_dim_size});
      auto make_sizes = [&](Eigen::DenseIndex split_size) {
        return Eigen::DSizes<Eigen::DenseIndex, 3>{prefix_dim_size, split_size,
                                                   suffix_dim_size};
      };
      auto reshape_result = [&](Tensor* result, Tlen split_size) {
        return result->shaped<T, 3>(
            {prefix_dim_size, split_size, suffix_dim_size});
      };
      SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 3>{}(
          context, input_reshaped, split_start_points, input_shape, split_dim,
          prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec,
          make_sizes, reshape_result);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// GPU implementation of SplitV.
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<GPUDevice, T, Tlen> Base;
  explicit SplitVOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32_t num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t split_dim_orig = context->input(2).flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));

    int32_t prefix_dim_size;
    int32_t split_dim_size;
    int32_t suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    // Use the same approach as concat (see documentation there):
    // reshape to 2D.

    if (num_split > 16) {
      GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
      OP_REQUIRES_OK(context, ptrs.Init());

      GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
      OP_REQUIRES_OK(context, offsets.Init());

      Tlen offset = 0;
      int entry = split_sizes_vec[0];
      bool fixed_size =
          std::all_of(split_sizes_vec.begin(), split_sizes_vec.end(),
                      [&entry](int n) { return n == entry; });

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        ptrs.Set(i, result->flat<T>().data());
        offsets.Set(i, offset);
        offset += split_sizes_vec[i] * suffix_dim_size;
      }
      offsets.Set(num_split, offset);
      OP_REQUIRES_OK(context, ptrs.Finalize());
      OP_REQUIRES_OK(context, offsets.Finalize());

      if (input.NumElements() > 0) {
        SplitVOpGPULaunch<T, Tlen>().Run(
            context->eigen_device<GPUDevice>(), fixed_size,
            input.flat<T>().data(), prefix_dim_size,
            input.NumElements() / prefix_dim_size, offsets.data(), ptrs.data());
        OP_REQUIRES(
            context, context->op_device_context()->stream()->ok(),
            errors::Internal("Launch of gpu kernel for SplitVOp failed"));
      }
    } else {
      Eigen::DenseIndex prefix_dim_size;
      Eigen::DenseIndex split_dim_size;
      Eigen::DenseIndex suffix_dim_size;

      std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
          Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
      auto input_reshaped = input.shaped<T, 2>(
          {prefix_dim_size, split_dim_size * suffix_dim_size});

      Eigen::DSizes<Eigen::DenseIndex, 2> indices{0, 0};

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 2> sizes{
            prefix_dim_size, split_sizes_vec[i] * suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 2>(
              {prefix_dim_size, split_sizes_vec[i] * suffix_dim_size});

          functor::SplitCustom<GPUDevice, T>()(
              context->eigen_device<GPUDevice>(), result_shaped, input_reshaped,
              indices, sizes);
        }
        indices[1] += split_sizes_vec[i] * suffix_dim_size;
      }
    }
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_SPLIT(type, len_type)                          \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpCPU<type, len_type>);

#define REGISTER_SPLIT_LEN(type) \
  REGISTER_SPLIT(type, int32);   \
  REGISTER_SPLIT(type, int64);

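// For illustration (added note), REGISTER_SPLIT_LEN(float) expands to two CPU
// kernel registrations, roughly:
//
//   REGISTER_KERNEL_BUILDER(Name("SplitV")
//                               .Device(DEVICE_CPU)
//                               .TypeConstraint<int32>("Tlen")
//                               .TypeConstraint<float>("T")
//                               .HostMemory("size_splits")
//                               .HostMemory("split_dim"),
//                           SplitVOpCPU<float, int32>);
//
// plus the analogous SplitVOpCPU<float, int64> registration.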
TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);

#undef REGISTER_SPLIT_LEN
#undef REGISTER_SPLIT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type, len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpGPU<type, len_type>);

#define REGISTER_GPU_LEN(type) \
  REGISTER_GPU(type, int32);   \
  REGISTER_GPU(type, int64);

TF_CALL_bfloat16(REGISTER_GPU_LEN);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_LEN);
#undef REGISTER_GPU_LEN
#undef REGISTER_GPU

// Special GPU kernel for int32: all inputs and outputs stay in host memory,
// so the CPU implementation is registered on the GPU device.

#define REGISTER_GPU_int32(len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<int32>("T")       \
                              .TypeConstraint<len_type>("Tlen") \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim")          \
                              .HostMemory("value")              \
                              .HostMemory("output"),            \
                          SplitVOpCPU<int32, len_type>);

REGISTER_GPU_int32(int32);
REGISTER_GPU_int32(int64);

#undef REGISTER_GPU_int32

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow