/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/transpose_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
//
// REQUIRES: input is a vector of int32 or int64.
// REQUIRES: input is a permutation of 0, 1, ..., n-1.
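//
// For example, p = [2, 0, 1] inverts to inv = [1, 2, 0]: p[0] == 2 forces
// inv[2] == 0, p[1] == 0 forces inv[0] == 1, and p[2] == 1 forces
// inv[1] == 2.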

template <typename T>
class InvertPermutationOp : public OpKernel {
 public:
  explicit InvertPermutationOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input.shape()),
        errors::InvalidArgument("invert_permutation expects a 1D vector."));
    auto Tin = input.vec<T>();
    OP_REQUIRES(context,
                FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
                errors::InvalidArgument("permutation of nonnegative int32s "
                                        "must have <= int32 max elements"));
    const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto Tout = output->vec<T>();
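    // Pre-fill the output with the sentinel -1 so that the duplicate check
    // below can tell whether a slot has already been written.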
    std::fill_n(Tout.data(), N, -1);
    for (int i = 0; i < N; ++i) {
      const T d = internal::SubtleMustCopy(Tin(i));
      OP_REQUIRES(context, FastBoundsCheck(d, N),
                  errors::InvalidArgument(d, " is not between 0 and ", N));
      OP_REQUIRES(context, Tout(d) == -1,
                  errors::InvalidArgument(d, " is duplicated in the input."));
      Tout(d) = i;
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
    InvertPermutationOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
    InvertPermutationOp<int64>);
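
// Note: the GPU and SYCL registrations below pin both "x" and "y" to host
// memory, so the same host-side kernel implementation serves those devices.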
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int32>);
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int64>);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int32>);
REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int64>("T")
                            .HostMemory("x")
                            .HostMemory("y"),
                        InvertPermutationOp<int64>);
#endif  // TENSORFLOW_USE_SYCL

namespace {
template <typename Tperm>
Status PermutationHelper(const Tensor& perm, const int dims,
                         std::vector<int32>* permutation) {
  auto Vperm = perm.vec<Tperm>();
  if (dims != Vperm.size()) {
    return errors::InvalidArgument("transpose expects a vector of size ", dims,
                                   ". But input(1) is a vector of size ",
                                   Vperm.size());
  }
  // Use volatile instead of SubtleMustCopy here so that the asynchrony
  // boundary is the copy into `permutation`.
  const volatile Tperm* perm_begin =
      reinterpret_cast<const volatile Tperm*>(Vperm.data());
  *permutation = std::vector<int32>(perm_begin, perm_begin + dims);

  return Status::OK();
}
}  // namespace

// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
// shuffles the dimensions of the input tensor according to permutation.
//
// Specifically, the returned tensor output meets the following condition:
// 1) output.dims() == input.dims();
// 2) output.dim_size(i) == input.dim_size(perm[i]);
// 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
//      input.tensor<T, N>(j_0, j_1, ..., j_N-1),
//    where i_s == j_{perm[s]}
//
// REQUIRES: perm is a vector of int32.
// REQUIRES: input.dims() == perm.size().
// REQUIRES: perm is a permutation.
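//
// For example, with input shape [2, 3, 4] and perm = [2, 0, 1], the output
// has shape [4, 2, 3] and satisfies output(x, y, z) == input(y, z, x).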

void TransposeOp::Compute(OpKernelContext* ctx) {
  const Tensor& input = ctx->input(0);
  const Tensor& perm = ctx->input(1);
  // Preliminary validation of sizes.
  OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
              errors::InvalidArgument("perm must be a vector, not ",
                                      perm.shape().DebugString()));

  // Although Tperm may be an int64 type, an int32 is sufficient to hold
  // dimension range values, so the narrowing here should be safe.
  std::vector<int32> permutation;
  const int dims = input.dims();
  if (perm.dtype() == DT_INT32) {
    OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
  } else {
    OP_REQUIRES_OK(ctx, PermutationHelper<int64>(perm, dims, &permutation));
  }
  TensorShape shape;

  // Check whether permutation is a permutation of integers of [0 .. dims).
  gtl::InlinedVector<bool, 8> bits(dims);
  bool is_identity = true;
  for (int i = 0; i < dims; ++i) {
    const int32 d = permutation[i];
    OP_REQUIRES(
        ctx, 0 <= d && d < dims,
        errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
    bits[d] = true;
    const auto dim_size = input.dim_size(d);
    shape.AddDim(dim_size);
    if (d != i) {
      is_identity = false;
    }
  }
  for (int i = 0; i < dims; ++i) {
    OP_REQUIRES(
        ctx, bits[i],
        errors::InvalidArgument(i, " is missing from {",
                                str_util::Join(permutation, ","), "}."));
  }

  // 0-D, 1-D, and identity transposes do nothing.
  if (!IsConjugate() && (dims <= 1 || is_identity)) {
    ctx->set_output(0, input);
    return;
  } else if (!IsConjugate() && internal::NonSingletonDimensionsAlign(
                                   input.shape(), permutation)) {
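    // The non-singleton dimensions keep their relative order under this
    // permutation, so the flat memory layout is unchanged and the transpose
    // reduces to a reshape; CopyFrom shares the input buffer.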
    Tensor output;
    OP_REQUIRES(ctx, output.CopyFrom(input, shape),
                errors::Unknown("Error reshaping Tensor."));
    ctx->set_output(0, output);
    return;
  }

  Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
  if (shape.num_elements() > 0) {
    OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output));
  }
}

Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                   gtl::ArraySlice<int32> perm, Tensor* out) {
  typedef Eigen::ThreadPoolDevice CPUDevice;
  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
                                   out);
}

Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
                                            const Tensor& in,
                                            gtl::ArraySlice<int32> perm,
                                            Tensor* out) {
  typedef Eigen::ThreadPoolDevice CPUDevice;
  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
                                            perm, out);
}

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          MklTransposeCpuOp);         \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          MklConjugateTransposeCpuOp);

#else  // INTEL_MKL && ENABLE_MKL
#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          TransposeCpuOp);            \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          ConjugateTransposeCpuOp);
#endif  // INTEL_MKL && ENABLE_MKL

TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER

#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                   gtl::ArraySlice<int32> perm, Tensor* out) {
  typedef Eigen::GpuDevice GPUDevice;
  return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
                                   out);
}
Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
                                            const Tensor& in,
                                            gtl::ArraySlice<int32> perm,
                                            Tensor* out) {
  typedef Eigen::GpuDevice GPUDevice;
  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<GPUDevice>(), in,
                                            perm, out);
}

#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          TransposeGpuOp);            \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_GPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          ConjugateTransposeGpuOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif

#ifdef TENSORFLOW_USE_SYCL
Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                    gtl::ArraySlice<int32> perm, Tensor* out) {
  typedef Eigen::SyclDevice SYCLDevice;
  return ::tensorflow::DoTranspose(ctx->eigen_device<SYCLDevice>(), in, perm,
                                   out);
}
Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
                                             const Tensor& in,
                                             gtl::ArraySlice<int32> perm,
                                             Tensor* out) {
  typedef Eigen::SyclDevice SYCLDevice;
  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
                                            perm, out);
}
#define REGISTER(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
                              .Device(DEVICE_SYCL)    \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          TransposeSyclOp);           \
  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
                              .Device(DEVICE_SYCL)    \
                              .TypeConstraint<T>("T") \
                              .HostMemory("perm"),    \
                          ConjugateTransposeSyclOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
}  // namespace tensorflow