/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace concat_split_util {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to
// 'output' using 'context' for the allocation to ensure proper device
// placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1, ..., ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
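  // For example, an input of shape {3, 4, 5} has 60 elements and is viewed
  // below as a {1, 60} matrix, so the rest of this function only ever deals
  // with rank-2 views.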
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64 output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<T>::value,
                                            output_shape, output, attr));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return Status::OK();
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return Status::OK();
}
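
// Example usage (a minimal sketch, not part of this header's API): when the
// element type is known at compile time, the templated Concat can be called
// directly. 'ctx', 'a' and 'b' below are placeholder names for a kernel's
// OpKernelContext* and two float tensors whose shapes agree in all but the
// zeroth dimension.
//
//   std::vector<Tensor> to_concat = {a, b};
//   Tensor concatenated;
//   OP_REQUIRES_OK(ctx, Concat<float>(ctx, to_concat, &concatenated));
//   // concatenated.dim_size(0) == a.dim_size(0) + b.dim_size(0)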

// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
inline Status Concat(OpKernelContext* context,
                     const gtl::ArraySlice<Tensor> inputs, Tensor* output) {
  const DataType type = inputs[0].dtype();
  Status concat_status;
  switch (type) {
#define CASE(type)                                         \
  case DataTypeToEnum<type>::value:                        \
    concat_status = Concat<type>(context, inputs, output); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return concat_status;
}
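
// Example usage (sketch): the non-templated overload is for the common case
// where the dtype is only known at run time. With the same placeholder names
// as above, from a helper that returns Status:
//
//   Tensor concatenated;
//   TF_RETURN_IF_ERROR(Concat(ctx, to_concat, &concatenated));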

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64> sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64 total_size = 0;
  for (const int64 size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return Status::OK();
  }

  // Special case 1: input is aligned.
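  // (Slicing along the zeroth dimension only yields suitably aligned
  // sub-tensors when the byte size of the inner dimensions is itself aligned;
  // when it is not, the caller falls back to the copying SplitCPU path.)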
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64 position = 0;
    for (const int64 size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return Status::OK();
  }

  return Status::OK();
}

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64> sizes,
                std::vector<Tensor>* outputs) {
  int64 suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});
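  // For example, a {10, 4, 5} input gives suffix_dim_size == 20, so
  // input_reshaped is a {10, 20} view and each split below copies out a
  // contiguous block of its rows.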

  int64 position = 0;
  for (const int64 size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output, attr));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{
        static_cast<Eigen::DenseIndex>(position), 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{
        static_cast<Eigen::DenseIndex>(size),
        static_cast<Eigen::DenseIndex>(suffix_dim_size)};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return Status::OK();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return Status::OK();
  }

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// TODO(olston, apassos): Handle non-CPU cases.
// return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return SplitCPU<T>(context, input, sizes, outputs);
}
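
// Example usage (sketch): splitting a batched tensor back into per-request
// pieces when the element type is known at compile time. 'ctx' and 'batched'
// are placeholder names; 'batched' is assumed to have six rows.
//
//   const std::vector<int64> sizes = {2, 1, 3};
//   std::vector<Tensor> pieces;
//   TF_RETURN_IF_ERROR(Split<float>(ctx, batched, sizes, &pieces));
//   // pieces[i].dim_size(0) == sizes[i] for each i.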

// Same as 'Split' above, but handles Tensor dtype automatically.
inline Status Split(OpKernelContext* context, const Tensor& input,
                    const gtl::ArraySlice<int64> sizes,
                    std::vector<Tensor>* outputs) {
  const DataType type = input.dtype();
  Status split_status;
  switch (type) {
#define CASE(type)                                              \
  case DataTypeToEnum<type>::value:                             \
    split_status = Split<type>(context, input, sizes, outputs); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      split_status = errors::InvalidArgument("Unsupported data type: ", type);
      break;
  }
  return split_status;
}
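
// Example round trip (sketch): Concat followed by Split with the original
// zeroth-dimension sizes recovers tensors of the original shapes. 'ctx', 'a'
// and 'b' are placeholder names for a kernel context and two same-dtype
// tensors whose shapes agree past the zeroth dimension; the calls are assumed
// to run in a helper that returns Status.
//
//   Tensor batched;
//   TF_RETURN_IF_ERROR(Concat(ctx, {a, b}, &batched));
//   std::vector<Tensor> unbatched;
//   TF_RETURN_IF_ERROR(
//       Split(ctx, batched, {a.dim_size(0), b.dim_size(0)}, &unbatched));
//   // unbatched[0] and unbatched[1] have the shapes of 'a' and 'b'.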

}  // namespace concat_split_util
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_