/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace concat_split_util {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to
// 'output' using 'context' for the allocation to ensure proper device
// placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1, ..., ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64 output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
                                            output_shape, output, attr));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return Status::OK();
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return Status::OK();
}
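
// Example (a non-normative sketch, not part of the library): concatenating
// two float tensors of shapes {2, 3} and {4, 3} with the explicitly typed
// overload above yields a tensor of shape {6, 3}. 'ctx', 'tensor_a', and
// 'tensor_b' are placeholder names assumed to be supplied by the caller.
//
//   std::vector<Tensor> inputs = {tensor_a, tensor_b};  // both DT_FLOAT
//   Tensor batched;
//   TF_RETURN_IF_ERROR(
//       concat_split_util::Concat<float>(ctx, inputs, &batched));
//   // 'batched' now has shape {6, 3}.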

// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
inline Status Concat(OpKernelContext* context,
                     const gtl::ArraySlice<Tensor> inputs, Tensor* output) {
  const DataType type = inputs[0].dtype();
  Status concat_status;
  switch (type) {
#define CASE(type)                                         \
  case DataTypeToEnum<type>::value:                        \
    concat_status = Concat<type>(context, inputs, output); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return concat_status;
}
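
// The dtype-deducing overload is convenient when the element type is only
// known at runtime, e.g. inside a type-agnostic batching kernel (illustrative
// sketch; 'ctx' and 'inputs' are assumed to come from the caller):
//
//   Tensor batched;
//   TF_RETURN_IF_ERROR(concat_split_util::Concat(ctx, inputs, &batched));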

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64> sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64 total_size = 0;
  for (const int64 size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return Status::OK();
  }

  // Special case 1: input is aligned.
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64 position = 0;
    for (const int64 size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return Status::OK();
  }

  return Status::OK();
}
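
// For illustration only (not part of the library contract): given an input of
// shape {6, 3}, sizes {6} triggers special case 0 and simply forwards the
// input, while sizes {2, 4} can be served by special case 1 as two zero-copy
// slices when the inner dimensions are suitably aligned for type T. 'ctx' and
// 'input' below are placeholders assumed to come from the caller.
//
//   std::vector<Tensor> parts;
//   bool done = false;
//   TF_RETURN_IF_ERROR(
//       SplitEasyCases<float>(ctx, input, {2, 4}, &parts, &done));
//   if (!done) {
//     // Neither special case applied; fall back to a copying split.
//   }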

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64> sizes,
                std::vector<Tensor>* outputs) {
  int64 suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});

  int64 position = 0;
  for (const int64 size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output, attr));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
        static_cast<Eigen::DenseIndex>(position), 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{
        static_cast<Eigen::DenseIndex>(size),
        static_cast<Eigen::DenseIndex>(suffix_dim_size)};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return Status::OK();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64> sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return Status::OK();
  }

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
  // TODO(olston, apassos): Handle non-CPU cases.
  // return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return SplitCPU<T>(context, input, sizes, outputs);
}
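
// Example (illustrative only; 'ctx' and 'batched' are placeholders assumed to
// come from the caller): splitting a {6, 3} float tensor into zeroth-dimension
// pieces of sizes 2 and 4.
//
//   std::vector<Tensor> pieces;
//   TF_RETURN_IF_ERROR(
//       concat_split_util::Split<float>(ctx, batched, {2, 4}, &pieces));
//   // pieces[0] has shape {2, 3}; pieces[1] has shape {4, 3}.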

// Same as 'Split' above, but handles Tensor dtype automatically.
inline Status Split(OpKernelContext* context, const Tensor& input,
                    const gtl::ArraySlice<int64> sizes,
                    std::vector<Tensor>* outputs) {
  const DataType type = input.dtype();
  Status split_status;
  switch (type) {
#define CASE(type)                                              \
  case DataTypeToEnum<type>::value:                             \
    split_status = Split<type>(context, input, sizes, outputs); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      split_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return split_status;
}
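
// As with 'Concat', the dtype-deducing overload is handy when the input dtype
// is only known at runtime (illustrative sketch, mirroring the typed example
// above):
//
//   std::vector<Tensor> pieces;
//   TF_RETURN_IF_ERROR(
//       concat_split_util::Split(ctx, batched, {2, 4}, &pieces));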

}  // namespace concat_split_util
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_