/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <numeric>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tlen>
class SplitVOpBase : public OpKernel {
 public:
  explicit SplitVOpBase(OpKernelConstruction* c) : OpKernel(c) {}

  void ComputeEasyCases(OpKernelContext* context, bool* done,
                        std::vector<Tlen>* split_sizes_vec) {
    const int32_t num_split = context->num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const Tensor& split_tensor = context->input(1);
    const Tensor& split_dim_tensor = context->input(2);

    OP_REQUIRES(context, split_dim_tensor.NumElements() == 1,
                errors::InvalidArgument("split_dim_tensor must have "
                                        "exactly one element."));

    const int32_t split_dim_orig = split_dim_tensor.flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    OP_REQUIRES(
        context,
        split_tensor.dims() == 1 && split_tensor.NumElements() == num_split,
        errors::InvalidArgument("split_tensor must be 1-D and have the same "
                                "number of elements as the number of outputs; "
                                "got ", split_tensor.dims(), "-D and ",
                                split_tensor.NumElements(), " elements"));

    auto split_sizes_d = split_tensor.vec<Tlen>();

    split_sizes_vec->resize(split_sizes_d.size());

    std::copy(split_sizes_d.data(), split_sizes_d.data() + split_sizes_d.size(),
              split_sizes_vec->begin());

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input.dims(),
        errors::InvalidArgument("-input rank(-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    Tlen input_size_split_dim = input_shape.dim_size(split_dim);

    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      context->set_output(0, context->input(0));
      OP_REQUIRES(
          context, (*split_sizes_vec)[0] == input_size_split_dim,
          errors::InvalidArgument("If there is only one output, it must have "
                                  "the same size as the input. Input size: ",
                                  input_size_split_dim,
                                  " output size: ", (*split_sizes_vec)[0]));
      *done = true;
      return;
    }

    // Determine the output sizes, resolving a single -1 entry if present.
    int neg_one_dim = -1;
    Tlen determined_size = 0;
    for (int d = 0; d < split_sizes_vec->size(); ++d) {
      Tlen size = (*split_sizes_vec)[d];

      if (size == -1) {
        OP_REQUIRES(context, neg_one_dim == -1,
                    errors::InvalidArgument("There can only be one -1 in the "
                                            "input."));
        neg_one_dim = d;
      } else {
        determined_size += size;
      }
    }

    OP_REQUIRES(
        context,
        (neg_one_dim == -1 && determined_size == input_size_split_dim) ||
            (neg_one_dim >= 0 && determined_size <= input_size_split_dim),
        errors::InvalidArgument("Determined shape must either match "
                                "input shape along split_dim exactly if "
                                "fully specified, or be less than the size of "
                                "the input along split_dim if not fully "
                                "specified. Got: ",
                                determined_size));

    if (neg_one_dim >= 0) {
      (*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size;
    }
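    // Example: with input_size_split_dim = 6 and size_splits = {2, -1, 1},
    // determined_size is 3, so the -1 entry resolves to 6 - 3 = 3.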

    // Special case 2: split along the first dimension. The requirement is that
    // either we are splitting the outer dimension into two or more parts such
    // that every outer subpart is aligned, or the split sizes guarantee that
    // the slices are always aligned. In these cases, the outputs can share the
    // underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned, the
    // resulting tensors must be aligned. It is conservative because if the
    // immediate consumers of the resulting tensors do not use Eigen for
    // computation, it is perfectly fine to avoid the copying.
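    // For instance, splitting a [4, 16] float tensor along dim 0 into sizes
    // {2, 2} produces slices at byte offsets 0 and 128; both are multiples of
    // typical Eigen alignments, so the outputs can alias the input buffer
    // without a copy.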
    if (SplitHasAlignedOutputsInFirstDimension(
            input_shape, split_dim, absl::MakeConstSpan(*split_sizes_vec))) {
      Tlen start = 0;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i,
                            input.Slice(start, start + (*split_sizes_vec)[i]));
        start += (*split_sizes_vec)[i];
      }
      *done = true;
      return;
    }
  }

  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, const int32_t split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    int32_t prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= input_shape.dim_size(i);
    }

    // Caller must ensure that dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }
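  // Example: for input_shape [2, 3, 4, 5] and split_dim = 2, SetDims returns
  // (prefix_dim_size, split_dim_size, suffix_dim_size) = (6, 4, 5).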

 private:
  // Determines whether the given split configuration can be done using slicing
  // on the first dimension of the tensor. The requirement is that each result
  // tensor from the slice is correctly aligned within the input tensor.
  static bool SplitHasAlignedOutputsInFirstDimension(
      const TensorShape& input_shape, int32_t split_dim,
      absl::Span<const Tlen> split_sizes) {
    if (split_dim != 0) {
      return false;
    }
    Tlen start = 0;
    for (const Tlen split_size : split_sizes) {
      if (!IsDim0SliceAligned<T>(input_shape, start, start + split_size)) {
        return false;
      }
      start += split_size;
    }
    return true;
  }
};

template <typename T, typename Tlen, typename InputReshapedType, int NDims>
class SplitVOpCPUImpl {
 public:
  void ParallelSplitByInputData(OpKernelContext* context,
                                const InputReshapedType& input_reshaped,
                                const TensorShape& input_shape,
                                const std::vector<Tlen>& split_sizes_vec,
                                const int32_t split_dim) const {
    const T* p_data = input_reshaped.data();
    const uint32 elem_pkg = input_reshaped.dimensions().rank() == 3
                                ? input_reshaped.dimension(2)
                                : 1;
    const uint32 line_elem_num =
        (input_reshaped.dimensions().rank() >= 2 ? input_reshaped.dimension(1)
                                                 : 1) *
        elem_pkg;
    const uint32 line_num = input_reshaped.dimension(0);

    // Allocate the output tensors and collect their base data pointers.
    std::vector<T*> outputs(split_sizes_vec.size());
    for (uint64 i = 0; i < split_sizes_vec.size(); ++i) {
      TensorShape output_shape(input_shape);
      output_shape.set_dim(split_dim, split_sizes_vec[i]);
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      outputs[i] = static_cast<T*>(&result->flat<T>()(0));
    }

    auto sub_split_func = [&split_sizes_vec, &p_data, elem_pkg, &outputs,
                           line_elem_num](int32_t start_part,
                                          int32_t end_part) {
      int start = start_part * line_elem_num;
      int end = end_part * line_elem_num;
      uint32 times = 0;
      for (int32_t i = start; i < end;) {
        for (uint32 j = 0; j < split_sizes_vec.size(); ++j) {
          const auto copy_elem_num = split_sizes_vec[j] * elem_pkg;
          std::copy_n(p_data + i, copy_elem_num,
                      &(outputs[j][(start_part + times) * copy_elem_num]));
          i += copy_elem_num;
        }
        ++times;
      }
    };

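    // Shard partitions the line_num outer rows into contiguous ranges and
    // invokes sub_split_func on each range, possibly from several worker
    // threads; e.g. with line_num = 8 and four threads, one invocation might
    // cover rows [0, 2).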
    uint32 part_size =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    Shard(part_size,
          context->device()->tensorflow_cpu_worker_threads()->workers, line_num,
          line_num, sub_split_func);
  }

  template <typename MakeSizesType, typename ReshapeResultType>
  void operator()(OpKernelContext* context,
                  const InputReshapedType& input_reshaped,
                  const std::vector<int64>& split_start_points,
                  const TensorShape& input_shape, int32_t split_dim,
                  Eigen::DenseIndex prefix_dim_size,
                  Eigen::DenseIndex split_dim_size,
                  Eigen::DenseIndex suffix_dim_size,
                  std::vector<Tlen>& split_sizes_vec,
                  const MakeSizesType& make_sizes,
                  const ReshapeResultType& reshape_result) const {
    Eigen::DSizes<Eigen::DenseIndex, NDims> indices;
    for (int i = 0; i < NDims; ++i) {
      indices[i] = 0;
    }
    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
    const int num_split = split_start_points.size();
    const bool use_parallelism_between_outputs =
        (num_split >= kMinimumSplitNum &&
         input_element_count >= std::min(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);
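    // For example, with eight worker threads, num_split = 8, and a
    // 1,000,000-element input, all three conditions hold and
    // use_parallelism_between_outputs is true.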

    auto range_output_func = [&indices, context, &input_shape, split_dim,
                              &split_sizes_vec, &split_start_points,
                              use_parallelism_between_outputs, &input_reshaped,
                              &make_sizes,
                              &reshape_result](int64_t start, int64_t limit) {
      for (int64_t i = start; i < limit; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        const auto sizes = make_sizes(split_sizes_vec[i]);

        if (sizes.TotalSize() > 0) {
          auto result_shaped = reshape_result(result, split_sizes_vec[i]);

          auto current_indices = indices;
          current_indices[NDims - 2] = split_start_points[i];
          if (use_parallelism_between_outputs) {
            // Each output is evaluated sequentially here; parallelism comes
            // from sharding across outputs.
            result_shaped = input_reshaped.slice(current_indices, sizes);
          } else {
            // The Split functor may parallelize internally.
            functor::Split<CPUDevice, T, NDims>()(
                context->eigen_device<CPUDevice>(), result_shaped,
                input_reshaped, current_indices, sizes);
          }
        }
      }
    };

    // Use ParallelSplitByInputData only when:
    // 1. the input is large enough that parallelism beats the serial path
    //    (>= kMinimumInputSize elements);
    // 2. there is enough data along the 0th dimension to spread across
    //    threads (> kMinimumDim0Size); and
    // 3. split_dim is non-zero (this path does not handle split_dim == 0).
    if ((input_element_count >= kMinimumInputSize) &&
        input_reshaped.dimension(0) > kMinimumDim0Size && split_dim) {
      // Each thread processes the same amount of input data and copies its
      // portion into every output tensor.
      ParallelSplitByInputData(context, input_reshaped, input_shape,
                               split_sizes_vec, split_dim);
    } else if (use_parallelism_between_outputs) {
      // Each shard handles one output tensor, slicing the corresponding region
      // of the input into it. Outputs are processed in parallel, with
      // parallelism inside the functor disabled.
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in the functor.
      range_output_func(0, num_split);
    }
  }
  static constexpr uint64 kMinimumInputSize = 4096 * 512;
  static constexpr uint64 kMinimumDim0Size = 8;
  static constexpr uint64 kMinimumSplitNum = 4;
};

template <typename T, typename Tlen>
class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<CPUDevice, T, Tlen> Base;
  explicit SplitVOpCPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32_t num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t split_dim_orig = context->input(2).flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    std::vector<int64> split_start_points(num_split);
    for (int i = 0; i < num_split; ++i) {
      if (i == 0) {
        split_start_points[i] = 0;
      } else {
        split_start_points[i] =
            split_start_points[i - 1] + split_sizes_vec[i - 1];
      }
    }
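    // e.g. split_sizes_vec = {2, 3, 1} yields split_start_points = {0, 2, 5}.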

    if (prefix_dim_size == 1) {
      auto input_reshaped =
          input.shaped<T, 2>({split_dim_size, suffix_dim_size});
      auto make_sizes = [&](Eigen::DenseIndex split_size) {
        return Eigen::DSizes<Eigen::DenseIndex, 2>{split_size, suffix_dim_size};
      };
      auto reshape_result = [&](Tensor* result, Tlen split_size) {
        return result->shaped<T, 2>({split_size, suffix_dim_size});
      };
      SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 2>{}(
          context, input_reshaped, split_start_points, input_shape, split_dim,
          prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec,
          make_sizes, reshape_result);
    } else {
      auto input_reshaped = input.shaped<T, 3>(
          {prefix_dim_size, split_dim_size, suffix_dim_size});
      auto make_sizes = [&](Eigen::DenseIndex split_size) {
        return Eigen::DSizes<Eigen::DenseIndex, 3>{prefix_dim_size, split_size,
                                                   suffix_dim_size};
      };
      auto reshape_result = [&](Tensor* result, Tlen split_size) {
        return result->shaped<T, 3>(
            {prefix_dim_size, split_size, suffix_dim_size});
      };
      SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 3>{}(
          context, input_reshaped, split_start_points, input_shape, split_dim,
          prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec,
          make_sizes, reshape_result);
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Partial specialization for GPU
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<GPUDevice, T, Tlen> Base;
  explicit SplitVOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32_t num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t split_dim_orig = context->input(2).flat<int32>()(0);
    const int32_t split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));

    int32_t prefix_dim_size;
    int32_t split_dim_size;
    int32_t suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    // use the same approach as concat (see documentation there)
    // reshape to 2D

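    // Two paths: with more than 16 outputs, a single fused GPU kernel is
    // launched using device arrays of output pointers and offsets; otherwise
    // each output is produced by a separate Eigen slice.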
    if (num_split > 16) {
      GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
      OP_REQUIRES_OK(context, ptrs.Init());

      GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
      OP_REQUIRES_OK(context, offsets.Init());

      Tlen offset = 0;
      int entry = split_sizes_vec[0];
      bool fixed_size =
          std::all_of(split_sizes_vec.begin(), split_sizes_vec.end(),
                      [&entry](int n) { return n == entry; });

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        ptrs.Set(i, result->flat<T>().data());
        offsets.Set(i, offset);
        offset += split_sizes_vec[i] * suffix_dim_size;
      }
      offsets.Set(num_split, offset);
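      // Example: split sizes {2, 2, 3} with suffix_dim_size = 4 give
      // offsets {0, 8, 16, 28}.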
      OP_REQUIRES_OK(context, ptrs.Finalize());
      OP_REQUIRES_OK(context, offsets.Finalize());

      if (input.NumElements() > 0) {
        SplitVOpGPULaunch<T, Tlen>().Run(
            context->eigen_device<GPUDevice>(), fixed_size,
            input.flat<T>().data(), prefix_dim_size,
            input.NumElements() / prefix_dim_size, offsets.data(), ptrs.data());
        OP_REQUIRES(
            context, context->op_device_context()->stream()->ok(),
            errors::Internal("Launch of gpu kernel for SplitVOp failed"));
      }
    } else {
      Eigen::DenseIndex prefix_dim_size;
      Eigen::DenseIndex split_dim_size;
      Eigen::DenseIndex suffix_dim_size;

      std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
          Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
      auto input_reshaped = input.shaped<T, 2>(
          {prefix_dim_size, split_dim_size * suffix_dim_size});

      Eigen::DSizes<Eigen::DenseIndex, 2> indices{0, 0};

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 2> sizes{
            prefix_dim_size, split_sizes_vec[i] * suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 2>(
              {prefix_dim_size, split_sizes_vec[i] * suffix_dim_size});

          functor::SplitCustom<GPUDevice, T>()(
              context->eigen_device<GPUDevice>(), result_shaped, input_reshaped,
              indices, sizes);
        }
        indices[1] += split_sizes_vec[i] * suffix_dim_size;
      }
    }
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

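// Kernel registrations. At the Python level, tf.split(value, size_splits,
// axis) with a list of sizes typically lowers to the SplitV op handled by
// these kernels.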
#define REGISTER_SPLIT(type, len_type)                          \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpCPU<type, len_type>);

#define REGISTER_SPLIT_LEN(type) \
  REGISTER_SPLIT(type, int32);   \
  REGISTER_SPLIT(type, int64);

TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);

#undef REGISTER_SPLIT_LEN
#undef REGISTER_SPLIT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type, len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpGPU<type, len_type>);

#define REGISTER_GPU_LEN(type) \
  REGISTER_GPU(type, int32);   \
  REGISTER_GPU(type, int64);

TF_CALL_bfloat16(REGISTER_GPU_LEN);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_LEN);
#undef REGISTER_GPU_LEN
#undef REGISTER_GPU

// special GPU kernel for int32

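// These registrations use SplitVOpCPU with all arguments in host memory:
// int32 tensors are conventionally kept on the host even on GPU devices, so
// the CPU implementation is used for them.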
#define REGISTER_GPU_int32(len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<int32>("T")       \
                              .TypeConstraint<len_type>("Tlen") \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim")          \
                              .HostMemory("value")              \
                              .HostMemory("output"),            \
                          SplitVOpCPU<int32, len_type>);

REGISTER_GPU_int32(int32);
REGISTER_GPU_int32(int64);

#undef REGISTER_GPU_int32

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow