• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc.
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/kernels/concat_lib.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 #if GOOGLE_CUDA
36 typedef Eigen::GpuDevice GPUDevice;
37 #endif  // GOOGLE_CUDA
38 #ifdef TENSORFLOW_USE_SYCL
39 typedef Eigen::SyclDevice SYCLDevice;
40 #endif  // TENSORFLOW_USE_SYCL
41 
42 enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
43 
44 // --------------------------------------------------------------------------
45 template <typename Device, typename T, AxisArgumentName AxisArgName>
46 class ConcatBaseOp : public OpKernel {
47  public:
48   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
49       ConstMatrixVector;
50 
ConcatBaseOp(OpKernelConstruction * c)51   explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
52 
Compute(OpKernelContext * c)53   void Compute(OpKernelContext* c) override {
54     const Tensor* concat_dim_tensor;
55     const char* axis_attribute_name =
56         AxisArgName == NAME_IS_AXIS ? "axis" : AxisArgName == NAME_IS_CONCAT_DIM
57                                                    ? "concat_dim"
58                                                    : "<invalid>";
59     OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
60     OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
61                 errors::InvalidArgument(
62                     axis_attribute_name,
63                     " tensor should be a scalar integer, but got shape ",
64                     concat_dim_tensor->shape().DebugString()));
65     int64 concat_dim;
66     // In case of ConcatV2, "axis" could be int32 or int64
67     if (AxisArgName == NAME_IS_AXIS) {
68       OP_REQUIRES(
69           c,
70           (concat_dim_tensor->dtype() == DT_INT32 ||
71            concat_dim_tensor->dtype() == DT_INT64),
72           errors::InvalidArgument(axis_attribute_name,
73                                   " tensor should be int32 or int64, but got ",
74                                   DataTypeString(concat_dim_tensor->dtype())));
75     } else {
76       OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32),
77                   errors::InvalidArgument(
78                       axis_attribute_name, " tensor should be int32, but got ",
79                       DataTypeString(concat_dim_tensor->dtype())));
80     }
81     if (concat_dim_tensor->dtype() == DT_INT32) {
82       concat_dim =
83           internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
84     } else {
85       concat_dim =
86           internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()());
87     }
88 
89     OpInputList values;
90     OP_REQUIRES_OK(c, c->input_list("values", &values));
91     const int N = values.size();
92     const int input_dims = values[0].dims();
93     const TensorShape& input_shape = values[0].shape();
94 
95     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
96     OP_REQUIRES(c,
97                 (0 <= axis && axis < input_dims) ||
98                     (allow_legacy_scalars() && concat_dim == 0),
99                 errors::InvalidArgument(
100                     "ConcatOp : Expected concatenating dimensions in the range "
101                     "[",
102                     -input_dims, ", ", input_dims, "), but got ", concat_dim));
103     // Note that we reduce the concat of n-dimensional tensors into a two
104     // dimensional concat. Assuming the dimensions of any input/output
105     // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
106     // the dimension indicated with size y0, we flatten it to {x, y}, where y =
107     // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
108     ConstMatrixVector inputs_flat;
109     inputs_flat.reserve(N);
110     int64 inputs_flat_dim0 = 1;
111     for (int d = 0; d < axis; ++d) {
112       inputs_flat_dim0 *= input_shape.dim_size(d);
113     }
114     int64 output_concat_dim = 0;
115     const bool input_is_scalar = IsLegacyScalar(input_shape);
116     for (int i = 0; i < N; ++i) {
117       const auto& in = values[i];
118       const bool in_is_scalar = IsLegacyScalar(in.shape());
119       OP_REQUIRES(
120           c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
121           errors::InvalidArgument(
122               "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
123               input_shape.DebugString(), " vs. shape[", i,
124               "] = ", in.shape().DebugString()));
125       for (int j = 0; j < input_dims; ++j) {
126         if (j == axis) {
127           continue;
128         }
129         OP_REQUIRES(
130             c, in.dim_size(j) == input_shape.dim_size(j),
131             errors::InvalidArgument(
132                 "ConcatOp : Dimensions of inputs should match: shape[0] = ",
133                 input_shape.DebugString(), " vs. shape[", i,
134                 "] = ", in.shape().DebugString()));
135       }
136       if (in.NumElements() > 0) {
137         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
138         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
139             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
140       }
141       // TODO(irving): Remove check once !allow_legacy_scalars().
142       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
143     }
144 
145     TensorShape output_shape(input_shape);
146     // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
147     if (output_shape.dims() == 0) {
148       output_shape.AddDim(output_concat_dim);
149     } else {
150       output_shape.set_dim(axis, output_concat_dim);
151     }
152     Tensor* output = nullptr;
153     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
154     if (output->NumElements() > 0) {
155       int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
156       auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
157 #if GOOGLE_CUDA
158       if (std::is_same<Device, GPUDevice>::value) {
159         ConcatGPU<T>(c, inputs_flat, output, &output_flat);
160         return;
161       }
162 #endif  // GOOGLE_CUDA
163 #ifdef TENSORFLOW_USE_SYCL
164       if (std::is_same<Device, SYCLDevice>::value) {
165         ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
166         return;
167       }
168 #endif  // TENSORFLOW_USE_SYCL
169       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
170     }
171   }
172 };
173 
174 template <typename Device, typename T>
175 using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
176 template <typename Device, typename T>
177 using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>;
178 
179 #define REGISTER_CONCAT(type)                            \
180   REGISTER_KERNEL_BUILDER(Name("Concat")                 \
181                               .Device(DEVICE_CPU)        \
182                               .TypeConstraint<type>("T") \
183                               .HostMemory("concat_dim"), \
184                           ConcatOp<CPUDevice, type>)     \
185   REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
186                               .Device(DEVICE_CPU)        \
187                               .TypeConstraint<type>("T") \
188                               .HostMemory("axis"),       \
189                           ConcatV2Op<CPUDevice, type>)
190 
191 TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
192 REGISTER_CONCAT(quint8);
193 REGISTER_CONCAT(qint8);
194 REGISTER_CONCAT(quint16);
195 REGISTER_CONCAT(qint16);
196 REGISTER_CONCAT(qint32);
197 
198 #undef REGISTER_CONCAT
199 
#if GOOGLE_CUDA

// GPU registration of "Concat" for one element type; only the axis input is
// declared as host memory.
#define REGISTER_CONCAT_GPU_KERNEL(type)                 \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<GPUDevice, type>)
// GPU registration of "ConcatV2" for one element type.
#define REGISTER_CONCAT_V2_GPU_KERNEL(type)              \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<GPUDevice, type>)
// Registers both concat ops on GPU for one element type.
#define REGISTER_GPU(type)         \
  REGISTER_CONCAT_GPU_KERNEL(type) \
  REGISTER_CONCAT_V2_GPU_KERNEL(type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
REGISTER_GPU(bfloat16);
TF_CALL_uint8(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bool);
#undef REGISTER_GPU
#undef REGISTER_CONCAT_V2_GPU_KERNEL
#undef REGISTER_CONCAT_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory,
// so the CPU-device implementation is used.
REGISTER_KERNEL_BUILDER(Name("Concat")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("concat_dim")
                            .HostMemory("values")
                            .HostMemory("output"),
                        ConcatOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ConcatV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("values")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ConcatV2Op<CPUDevice, int32>);

#endif  // GOOGLE_CUDA
242 
#ifdef TENSORFLOW_USE_SYCL
// SYCL registration of "Concat" for one element type; only the axis input is
// declared as host memory.
#define REGISTER_CONCAT_SYCL_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<SYCLDevice, type>)
// SYCL registration of "ConcatV2" for one element type.
#define REGISTER_CONCAT_V2_SYCL_KERNEL(type)             \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<SYCLDevice, type>)
// Registers both concat ops on SYCL for one element type.
#define REGISTER_SYCL(type)         \
  REGISTER_CONCAT_SYCL_KERNEL(type) \
  REGISTER_CONCAT_V2_SYCL_KERNEL(type)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

// As on GPU, int32 concat on SYCL keeps every argument in host memory and
// runs the CPU-device implementation.
REGISTER_KERNEL_BUILDER(Name("Concat")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("concat_dim")
                            .HostMemory("values")
                            .HostMemory("output"),
                        ConcatOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ConcatV2")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("values")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ConcatV2Op<CPUDevice, int32>);

#undef REGISTER_SYCL
#undef REGISTER_CONCAT_V2_SYCL_KERNEL
#undef REGISTER_CONCAT_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL
275 
276 class ConcatOffsetOp : public OpKernel {
277  public:
ConcatOffsetOp(OpKernelConstruction * ctx)278   explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
279 
Compute(OpKernelContext * ctx)280   void Compute(OpKernelContext* ctx) override {
281     const Tensor& concat_dim = ctx->input(0);
282     OP_REQUIRES(
283         ctx, IsLegacyScalar(concat_dim.shape()),
284         errors::InvalidArgument(
285             "Concat dim tensor should be a scalar integer, but got shape ",
286             concat_dim.shape().DebugString()));
287     for (int i = 1; i < ctx->num_inputs(); ++i) {
288       const Tensor& inp = ctx->input(i);
289       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()),
290                   errors::InvalidArgument("input ", i,
291                                           " should be a vector, but got shape ",
292                                           inp.shape().DebugString()));
293     }
294     // Suppose a Concat() op needs to Concatenate N tensors, each of
295     // which has the same number of dimensions.  Their shapes match
296     // except the concat dimension.
297     //
298     // E.g., say, we want to concatenate 3 tensors in the 2nd
299     // dimension, and their shapes are:
300     //
301     //  [2, 2, 5, 7]
302     //  [2, 3, 5, 7]
303     //  [2, 4, 5, 7]
304     //
305     // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
306     // [2,9,5,7]. We will compute the cumulative sum along the 2nd
307     // dimension to figure out each input's offset in the concatenated
308     // output:
309     //  [0, 0, 0, 0]
310     //  [0, 2, 0, 0]
311     //  [0, 5, 0, 0]
312     const int32 N = ctx->num_inputs() - 1;
313     const Tensor& inp0 = ctx->input(1);
314     auto inp0_vec = inp0.vec<int32>();
315     const int64 cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()());
316     const int64 dims = inp0.NumElements();
317     int32 axis = cdim < 0 ? cdim + dims : cdim;
318     OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
319                 errors::InvalidArgument("Concat dim is out of range: ", cdim,
320                                         " vs. ", dims));
321     int32 offset = 0;
322     for (int i = 0; i < N; ++i) {
323       const Tensor& inp = ctx->input(1 + i);
324       OP_REQUIRES(
325           ctx, dims == inp.NumElements(),
326           errors::InvalidArgument("input ", i, " should contain ", dims,
327                                   " elements, but got ", inp.NumElements()));
328       auto inp_vec = inp.vec<int32>();
329       Tensor* out = nullptr;
330       OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
331       auto out_vec = out->vec<int32>();
332       for (int64 j = 0; j < dims; ++j) {
333         if (j == axis) {
334           out_vec(j) = offset;
335           offset += inp_vec(j);
336         } else {
337           OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)),
338                       errors::InvalidArgument(
339                           "All dimensions except ", axis, " must match. Input ",
340                           i, " has shape [", inp.SummarizeValue(10),
341                           "] and doesn't match input 0 with shape [",
342                           inp0.SummarizeValue(10), "]."));
343           out_vec(j) = 0;
344         }
345       }
346     }
347   }
348 
IsExpensive()349   bool IsExpensive() override { return false; }
350 };
351 
352 REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU),
353                         ConcatOffsetOp);
354 
355 REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
356                             .Device(DEVICE_GPU)
357                             .HostMemory("concat_dim")
358                             .HostMemory("shape")
359                             .HostMemory("offset"),
360                         ConcatOffsetOp);
361 
362 #ifdef TENSORFLOW_USE_SYCL
363 REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
364                             .Device(DEVICE_SYCL)
365                             .HostMemory("concat_dim")
366                             .HostMemory("shape")
367                             .HostMemory("offset"),
368                         ConcatOffsetOp);
369 #endif  // TENSORFLOW_USE_SYCL
370 }  // namespace tensorflow
371