/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
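// The Pack op stacks N tensors of identical shape into a single tensor of
// rank one higher, inserting a new dimension of size N at the position given
// by the "axis" attribute; it is the kernel behind tf.stack in the Python API.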

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// --------------------------------------------------------------------------
template <typename Device, typename T>
class PackOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit PackOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* c) override {
    OpInputList values;
    OP_REQUIRES_OK(c, c->input_list("values", &values));
    const int num = values.size();

    // Verify that all input shapes match
    for (int i = 1; i < num; i++) {
      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
                  errors::InvalidArgument(
                      "Shapes of all inputs must match: values[0].shape = ",
                      values[0].shape().DebugString(), " != values[", i,
                      "].shape = ", values[i].shape().DebugString()));
    }

    int expanded_num_dims = values[0].dims() + 1;
    int axis = axis_;
    if (axis < 0) axis += expanded_num_dims;

    OP_REQUIRES(c, 0 <= axis && axis < expanded_num_dims,
                errors::InvalidArgument("axis = ", axis_, " not in [",
                                        -expanded_num_dims, ", ",
                                        expanded_num_dims, ")"));

    TensorShape output_shape(values[0].shape());
    output_shape.InsertDim(axis, num);
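    // For example, packing num = 3 inputs of shape [2, 4] along axis = 1
    // produces an output of shape [2, 3, 4].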

    // In the num = 1 case, just reshape the input
    if (num == 1) {
      Tensor output;
      CHECK(output.CopyFrom(values[0], output_shape));
      c->set_output(0, output);
      return;
    }

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));

    int64 before_dim = 1;
    for (int i = 0; i < axis; ++i) {
      before_dim *= output_shape.dim_size(i);
    }

    int64 after_dim = 1;
    for (int i = axis + 1; i < output_shape.dims(); ++i) {
      after_dim *= output_shape.dim_size(i);
    }

    const int64 axis_dim = output_shape.dim_size(axis);
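    // Continuing the example above: before_dim = 2, axis_dim = 3, and
    // after_dim = 4, so each input is viewed as a [2, 4] matrix, the output as
    // a [2, 12] matrix, and packing reduces to concatenating the inputs along
    // the second (column) dimension.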

    const int64 output_size = output->NumElements();
    if (output_size > 0) {
      auto output_flat =
          output->shaped<T, 2>({before_dim, after_dim * axis_dim});

      // Except for shapes, pack is a special case of concat, so we reuse the
      // same computational kernels.
      ConstMatrixVector inputs_flat;
      inputs_flat.reserve(num);
      for (int i = 0; i < num; ++i) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            values[i].shaped<T, 2>({before_dim, after_dim})));
      }
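      // Dispatch to the concat implementation matching the Device template
      // parameter: GPU, SYCL, or (by default) CPU.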
#if GOOGLE_CUDA
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
      if (std::is_same<Device, SYCLDevice>::value) {
        ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
        return;
      }
#endif  // TENSORFLOW_USE_SYCL
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int axis_;
};

#define REGISTER_PACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PackOp<CPUDevice, type>)

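// Register the CPU kernel for every standard TensorFlow dtype as well as the
// quantized dtypes.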
TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
REGISTER_PACK(string);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION)

#undef REGISTER_PACK

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PackOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int16(REGISTER_GPU);
TF_CALL_bool(REGISTER_GPU);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("Pack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PackOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
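// As on GPU, int32 Pack on SYCL keeps its inputs and outputs in host memory
// and reuses the CPU kernel.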
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_SYCL)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow