1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #include <limits> 19 #include <vector> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/bounds_check.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_types.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/kernels/concat_lib.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace tensorflow { 33 34 typedef Eigen::ThreadPoolDevice CPUDevice; 35 #if GOOGLE_CUDA 36 typedef Eigen::GpuDevice GPUDevice; 37 #endif // GOOGLE_CUDA 38 #ifdef TENSORFLOW_USE_SYCL 39 typedef Eigen::SyclDevice SYCLDevice; 40 #endif // TENSORFLOW_USE_SYCL 41 42 enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; 43 44 // -------------------------------------------------------------------------- 45 template <typename Device, typename T, AxisArgumentName AxisArgName> 46 class ConcatBaseOp : public OpKernel { 47 public: 48 typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> 49 ConstMatrixVector; 50 ConcatBaseOp(OpKernelConstruction * c)51 explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} 52 Compute(OpKernelContext * c)53 void Compute(OpKernelContext* c) override { 54 const Tensor* concat_dim_tensor; 55 const char* axis_attribute_name = 56 AxisArgName == NAME_IS_AXIS ? "axis" : AxisArgName == NAME_IS_CONCAT_DIM 57 ? "concat_dim" 58 : "<invalid>"; 59 OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); 60 OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()), 61 errors::InvalidArgument( 62 axis_attribute_name, 63 " tensor should be a scalar integer, but got shape ", 64 concat_dim_tensor->shape().DebugString())); 65 int64 concat_dim; 66 // In case of ConcatV2, "axis" could be int32 or int64 67 if (AxisArgName == NAME_IS_AXIS) { 68 OP_REQUIRES( 69 c, 70 (concat_dim_tensor->dtype() == DT_INT32 || 71 concat_dim_tensor->dtype() == DT_INT64), 72 errors::InvalidArgument(axis_attribute_name, 73 " tensor should be int32 or int64, but got ", 74 DataTypeString(concat_dim_tensor->dtype()))); 75 } else { 76 OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32), 77 errors::InvalidArgument( 78 axis_attribute_name, " tensor should be int32, but got ", 79 DataTypeString(concat_dim_tensor->dtype()))); 80 } 81 if (concat_dim_tensor->dtype() == DT_INT32) { 82 concat_dim = 83 internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()()); 84 } else { 85 concat_dim = 86 internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()()); 87 } 88 89 OpInputList values; 90 OP_REQUIRES_OK(c, c->input_list("values", &values)); 91 const int N = values.size(); 92 const int input_dims = values[0].dims(); 93 const TensorShape& input_shape = values[0].shape(); 94 95 int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; 96 OP_REQUIRES(c, 97 (0 <= axis && axis < input_dims) || 98 (allow_legacy_scalars() && concat_dim == 0), 99 errors::InvalidArgument( 100 "ConcatOp : Expected concatenating dimensions in the range " 101 "[", 102 -input_dims, ", ", input_dims, "), but got ", concat_dim)); 103 // Note that we reduce the concat of n-dimensional tensors into a two 104 // dimensional concat. Assuming the dimensions of any input/output 105 // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along 106 // the dimension indicated with size y0, we flatten it to {x, y}, where y = 107 // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1). 108 ConstMatrixVector inputs_flat; 109 inputs_flat.reserve(N); 110 int64 inputs_flat_dim0 = 1; 111 for (int d = 0; d < axis; ++d) { 112 inputs_flat_dim0 *= input_shape.dim_size(d); 113 } 114 int64 output_concat_dim = 0; 115 const bool input_is_scalar = IsLegacyScalar(input_shape); 116 for (int i = 0; i < N; ++i) { 117 const auto& in = values[i]; 118 const bool in_is_scalar = IsLegacyScalar(in.shape()); 119 OP_REQUIRES( 120 c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), 121 errors::InvalidArgument( 122 "ConcatOp : Ranks of all input tensors should match: shape[0] = ", 123 input_shape.DebugString(), " vs. shape[", i, 124 "] = ", in.shape().DebugString())); 125 for (int j = 0; j < input_dims; ++j) { 126 if (j == axis) { 127 continue; 128 } 129 OP_REQUIRES( 130 c, in.dim_size(j) == input_shape.dim_size(j), 131 errors::InvalidArgument( 132 "ConcatOp : Dimensions of inputs should match: shape[0] = ", 133 input_shape.DebugString(), " vs. shape[", i, 134 "] = ", in.shape().DebugString())); 135 } 136 if (in.NumElements() > 0) { 137 int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; 138 inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( 139 in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); 140 } 141 // TODO(irving): Remove check once !allow_legacy_scalars(). 142 output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1; 143 } 144 145 TensorShape output_shape(input_shape); 146 // TODO(irving): Remove rank 0 case once !allow_legacy_scalars(). 147 if (output_shape.dims() == 0) { 148 output_shape.AddDim(output_concat_dim); 149 } else { 150 output_shape.set_dim(axis, output_concat_dim); 151 } 152 Tensor* output = nullptr; 153 OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); 154 if (output->NumElements() > 0) { 155 int64 output_dim1 = output->NumElements() / inputs_flat_dim0; 156 auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1}); 157 #if GOOGLE_CUDA 158 if (std::is_same<Device, GPUDevice>::value) { 159 ConcatGPU<T>(c, inputs_flat, output, &output_flat); 160 return; 161 } 162 #endif // GOOGLE_CUDA 163 #ifdef TENSORFLOW_USE_SYCL 164 if (std::is_same<Device, SYCLDevice>::value) { 165 ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat); 166 return; 167 } 168 #endif // TENSORFLOW_USE_SYCL 169 ConcatCPU<T>(c->device(), inputs_flat, &output_flat); 170 } 171 } 172 }; 173 174 template <typename Device, typename T> 175 using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>; 176 template <typename Device, typename T> 177 using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>; 178 179 #define REGISTER_CONCAT(type) \ 180 REGISTER_KERNEL_BUILDER(Name("Concat") \ 181 .Device(DEVICE_CPU) \ 182 .TypeConstraint<type>("T") \ 183 .HostMemory("concat_dim"), \ 184 ConcatOp<CPUDevice, type>) \ 185 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 186 .Device(DEVICE_CPU) \ 187 .TypeConstraint<type>("T") \ 188 .HostMemory("axis"), \ 189 ConcatV2Op<CPUDevice, type>) 190 191 TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); 192 REGISTER_CONCAT(quint8); 193 REGISTER_CONCAT(qint8); 194 REGISTER_CONCAT(quint16); 195 REGISTER_CONCAT(qint16); 196 REGISTER_CONCAT(qint32); 197 198 #undef REGISTER_CONCAT 199 200 #if GOOGLE_CUDA 201 202 #define REGISTER_GPU(type) \ 203 REGISTER_KERNEL_BUILDER(Name("Concat") \ 204 .Device(DEVICE_GPU) \ 205 .TypeConstraint<type>("T") \ 206 .HostMemory("concat_dim"), \ 207 ConcatOp<GPUDevice, type>) \ 208 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 209 .Device(DEVICE_GPU) \ 210 .TypeConstraint<type>("T") \ 211 .HostMemory("axis"), \ 212 ConcatV2Op<GPUDevice, type>) 213 214 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); 215 REGISTER_GPU(bfloat16); 216 TF_CALL_uint8(REGISTER_GPU); 217 TF_CALL_complex64(REGISTER_GPU); 218 TF_CALL_complex128(REGISTER_GPU); 219 TF_CALL_int64(REGISTER_GPU); 220 REGISTER_GPU(bool); 221 #undef REGISTER_GPU 222 223 // A special GPU kernel for int32. 224 // TODO(b/25387198): Also enable int32 in device memory. This kernel 225 // registration requires all int32 inputs and outputs to be in host memory. 226 REGISTER_KERNEL_BUILDER(Name("Concat") 227 .Device(DEVICE_GPU) 228 .TypeConstraint<int32>("T") 229 .HostMemory("concat_dim") 230 .HostMemory("values") 231 .HostMemory("output"), 232 ConcatOp<CPUDevice, int32>); 233 REGISTER_KERNEL_BUILDER(Name("ConcatV2") 234 .Device(DEVICE_GPU) 235 .TypeConstraint<int32>("T") 236 .HostMemory("values") 237 .HostMemory("axis") 238 .HostMemory("output"), 239 ConcatV2Op<CPUDevice, int32>); 240 241 #endif // GOOGLE_CUDA 242 243 #ifdef TENSORFLOW_USE_SYCL 244 #define REGISTER_SYCL(type) \ 245 REGISTER_KERNEL_BUILDER(Name("Concat") \ 246 .Device(DEVICE_SYCL) \ 247 .TypeConstraint<type>("T") \ 248 .HostMemory("concat_dim"), \ 249 ConcatOp<SYCLDevice, type>) \ 250 REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ 251 .Device(DEVICE_SYCL) \ 252 .TypeConstraint<type>("T") \ 253 .HostMemory("axis"), \ 254 ConcatV2Op<SYCLDevice, type>) 255 256 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL); 257 258 REGISTER_KERNEL_BUILDER(Name("Concat") 259 .Device(DEVICE_SYCL) 260 .TypeConstraint<int32>("T") 261 .HostMemory("concat_dim") 262 .HostMemory("values") 263 .HostMemory("output"), 264 ConcatOp<CPUDevice, int32>); 265 REGISTER_KERNEL_BUILDER(Name("ConcatV2") 266 .Device(DEVICE_SYCL) 267 .TypeConstraint<int32>("T") 268 .HostMemory("values") 269 .HostMemory("axis") 270 .HostMemory("output"), 271 ConcatV2Op<CPUDevice, int32>); 272 273 #undef REGISTER_SYCL 274 #endif // TENSORFLOW_USE_SYCL 275 276 class ConcatOffsetOp : public OpKernel { 277 public: ConcatOffsetOp(OpKernelConstruction * ctx)278 explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 279 Compute(OpKernelContext * ctx)280 void Compute(OpKernelContext* ctx) override { 281 const Tensor& concat_dim = ctx->input(0); 282 OP_REQUIRES( 283 ctx, IsLegacyScalar(concat_dim.shape()), 284 errors::InvalidArgument( 285 "Concat dim tensor should be a scalar integer, but got shape ", 286 concat_dim.shape().DebugString())); 287 for (int i = 1; i < ctx->num_inputs(); ++i) { 288 const Tensor& inp = ctx->input(i); 289 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()), 290 errors::InvalidArgument("input ", i, 291 " should be a vector, but got shape ", 292 inp.shape().DebugString())); 293 } 294 // Suppose a Concat() op needs to Concatenate N tensors, each of 295 // which has the same number of dimensions. Their shapes match 296 // except the concat dimension. 297 // 298 // E.g., say, we want to concatenate 3 tensors in the 2nd 299 // dimension, and their shapes are: 300 // 301 // [2, 2, 5, 7] 302 // [2, 3, 5, 7] 303 // [2, 4, 5, 7] 304 // 305 // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape 306 // [2,9,5,7]. We will compute the cumulative sum along the 2nd 307 // dimension to figure out each input's offset in the concatenated 308 // output: 309 // [0, 0, 0, 0] 310 // [0, 2, 0, 0] 311 // [0, 5, 0, 0] 312 const int32 N = ctx->num_inputs() - 1; 313 const Tensor& inp0 = ctx->input(1); 314 auto inp0_vec = inp0.vec<int32>(); 315 const int64 cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()()); 316 const int64 dims = inp0.NumElements(); 317 int32 axis = cdim < 0 ? cdim + dims : cdim; 318 OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), 319 errors::InvalidArgument("Concat dim is out of range: ", cdim, 320 " vs. ", dims)); 321 int32 offset = 0; 322 for (int i = 0; i < N; ++i) { 323 const Tensor& inp = ctx->input(1 + i); 324 OP_REQUIRES( 325 ctx, dims == inp.NumElements(), 326 errors::InvalidArgument("input ", i, " should contain ", dims, 327 " elements, but got ", inp.NumElements())); 328 auto inp_vec = inp.vec<int32>(); 329 Tensor* out = nullptr; 330 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); 331 auto out_vec = out->vec<int32>(); 332 for (int64 j = 0; j < dims; ++j) { 333 if (j == axis) { 334 out_vec(j) = offset; 335 offset += inp_vec(j); 336 } else { 337 OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)), 338 errors::InvalidArgument( 339 "All dimensions except ", axis, " must match. Input ", 340 i, " has shape [", inp.SummarizeValue(10), 341 "] and doesn't match input 0 with shape [", 342 inp0.SummarizeValue(10), "].")); 343 out_vec(j) = 0; 344 } 345 } 346 } 347 } 348 IsExpensive()349 bool IsExpensive() override { return false; } 350 }; 351 352 REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU), 353 ConcatOffsetOp); 354 355 REGISTER_KERNEL_BUILDER(Name("ConcatOffset") 356 .Device(DEVICE_GPU) 357 .HostMemory("concat_dim") 358 .HostMemory("shape") 359 .HostMemory("offset"), 360 ConcatOffsetOp); 361 362 #ifdef TENSORFLOW_USE_SYCL 363 REGISTER_KERNEL_BUILDER(Name("ConcatOffset") 364 .Device(DEVICE_SYCL) 365 .HostMemory("concat_dim") 366 .HostMemory("shape") 367 .HostMemory("offset"), 368 ConcatOffsetOp); 369 #endif // TENSORFLOW_USE_SYCL 370 } // namespace tensorflow 371