• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/transpose_op.h"
21 
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/kernels/transpose_functor.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace tensorflow {
33 
34 // inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
35 // integers 0, 1, ..., n - 1 and returns the inverted
36 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
37 //
38 // REQUIRES: input is a vector of int32 or int64.
39 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
40 
41 template <typename T>
42 class InvertPermutationOp : public OpKernel {
43  public:
InvertPermutationOp(OpKernelConstruction * context)44   explicit InvertPermutationOp(OpKernelConstruction* context)
45       : OpKernel(context) {}
46 
Compute(OpKernelContext * context)47   void Compute(OpKernelContext* context) override {
48     const Tensor& input = context->input(0);
49     OP_REQUIRES(
50         context, TensorShapeUtils::IsVector(input.shape()),
51         errors::InvalidArgument("invert_permutation expects a 1D vector."));
52     auto Tin = input.vec<T>();
53     OP_REQUIRES(context,
54                 FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
55                 errors::InvalidArgument("permutation of nonnegative int32s "
56                                         "must have <= int32 max elements"));
57     const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
58     Tensor* output = nullptr;
59     OP_REQUIRES_OK(context,
60                    context->allocate_output(0, input.shape(), &output));
61     auto Tout = output->vec<T>();
62     std::fill_n(Tout.data(), N, -1);
63     for (int i = 0; i < N; ++i) {
64       const T d = internal::SubtleMustCopy(Tin(i));
65       OP_REQUIRES(context, FastBoundsCheck(d, N),
66                   errors::InvalidArgument(d, " is not between 0 and ", N));
67       OP_REQUIRES(context, Tout(d) == -1,
68                   errors::InvalidArgument(d, " is duplicated in the input."));
69       Tout(d) = i;
70     }
71   }
72 };
73 
74 REGISTER_KERNEL_BUILDER(
75     Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
76     InvertPermutationOp<int32>);
77 REGISTER_KERNEL_BUILDER(
78     Name("InvertPermutation").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
79     InvertPermutationOp<int64>);
80 
81 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
82                             .Device(DEVICE_GPU)
83                             .TypeConstraint<int32>("T")
84                             .HostMemory("x")
85                             .HostMemory("y"),
86                         InvertPermutationOp<int32>);
87 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
88                             .Device(DEVICE_GPU)
89                             .TypeConstraint<int64>("T")
90                             .HostMemory("x")
91                             .HostMemory("y"),
92                         InvertPermutationOp<int64>);
93 
94 #ifdef TENSORFLOW_USE_SYCL
95 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
96                             .Device(DEVICE_SYCL)
97                             .TypeConstraint<int32>("T")
98                             .HostMemory("x")
99                             .HostMemory("y"),
100                         InvertPermutationOp<int32>);
101 REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
102                             .Device(DEVICE_SYCL)
103                             .TypeConstraint<int64>("T")
104                             .HostMemory("x")
105                             .HostMemory("y"),
106                         InvertPermutationOp<int64>);
107 #endif  // TENSORFLOW_USE_SYCL
108 
109 namespace {
110 template <typename Tperm>
PermutationHelper(const Tensor & perm,const int dims,std::vector<int32> * permutation)111 Status PermutationHelper(const Tensor& perm, const int dims,
112                          std::vector<int32>* permutation) {
113   auto Vperm = perm.vec<Tperm>();
114   if (dims != Vperm.size()) {
115     return errors::InvalidArgument("transpose expects a vector of size ", dims,
116                                    ". But input(1) is a vector of size ",
117                                    Vperm.size());
118   }
119   // using volatile instead of SubtleMustCopy here so that the
120   // asynchrony boundary is permutation.
121   const volatile Tperm* perm_begin =
122       reinterpret_cast<const volatile Tperm*>(Vperm.data());
123   *permutation = std::vector<int32>(perm_begin, perm_begin + dims);
124 
125   return Status::OK();
126 }
127 }  // namespace
128 
129 // output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
130 // of type T and rank N, and a permutation of 0, 1, ..., N-1. It
131 // shuffles the dimensions of the input tensor according to permutation.
132 //
133 // Specifically, the returned tensor output meets the following condition:
134 // 1) output.dims() == input.dims();
135 // 2) output.dim_size(i) == input.dim_size(perm[i]);
136 // 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
137 //      input.tensor<T, N>(j_0, j_1, ..., j_N-1),
138 //    where i_s == j_{perm[s]}
139 //
140 // REQUIRES: perm is a vector of int32.
141 // REQUIRES: input.dims() == perm.size().
142 // REQUIRES: perm is a permutation.
143 
Compute(OpKernelContext * ctx)144 void TransposeOp::Compute(OpKernelContext* ctx) {
145   const Tensor& input = ctx->input(0);
146   const Tensor& perm = ctx->input(1);
147   // Preliminary validation of sizes.
148   OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
149               errors::InvalidArgument("perm must be a vector, not ",
150                                       perm.shape().DebugString()));
151 
152   // Although Tperm may be an int64 type, an int32 is sufficient to hold
153   // dimension range values, so the narrowing here should be safe.
154   std::vector<int32> permutation;
155   const int dims = input.dims();
156   if (perm.dtype() == DT_INT32) {
157     OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
158   } else {
159     OP_REQUIRES_OK(ctx, PermutationHelper<int64>(perm, dims, &permutation));
160   }
161   TensorShape shape;
162 
163   // Check whether permutation is a permutation of integers of [0 .. dims).
164   gtl::InlinedVector<bool, 8> bits(dims);
165   bool is_identity = true;
166   for (int i = 0; i < dims; ++i) {
167     const int32 d = permutation[i];
168     OP_REQUIRES(
169         ctx, 0 <= d && d < dims,
170         errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
171     bits[d] = true;
172     const auto dim_size = input.dim_size(d);
173     shape.AddDim(dim_size);
174     if (d != i) {
175       is_identity = false;
176     }
177   }
178   for (int i = 0; i < dims; ++i) {
179     OP_REQUIRES(
180         ctx, bits[i],
181         errors::InvalidArgument(i, " is missing from {",
182                                 str_util::Join(permutation, ","), "}."));
183   }
184 
185   // 0-D, 1-D, and identity transposes do nothing.
186   if (!IsConjugate() && (dims <= 1 || is_identity)) {
187     ctx->set_output(0, input);
188     return;
189   } else if (!IsConjugate() && internal::NonSingletonDimensionsAlign(
190                                    input.shape(), permutation)) {
191     Tensor output;
192     OP_REQUIRES(ctx, output.CopyFrom(input, shape),
193                 errors::Unknown("Error reshaping Tensor."));
194     ctx->set_output(0, output);
195     return;
196   }
197 
198   Tensor* output = nullptr;
199   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
200   if (shape.num_elements() > 0) {
201     OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output));
202   }
203 }
204 
DoTranspose(OpKernelContext * ctx,const Tensor & in,gtl::ArraySlice<int32> perm,Tensor * out)205 Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
206                                    gtl::ArraySlice<int32> perm, Tensor* out) {
207   typedef Eigen::ThreadPoolDevice CPUDevice;
208   return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
209                                    out);
210 }
211 
DoTranspose(OpKernelContext * ctx,const Tensor & in,gtl::ArraySlice<int32> perm,Tensor * out)212 Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
213                                             const Tensor& in,
214                                             gtl::ArraySlice<int32> perm,
215                                             Tensor* out) {
216   typedef Eigen::ThreadPoolDevice CPUDevice;
217   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
218                                             perm, out);
219 }
220 
221 #if defined(INTEL_MKL) && defined(ENABLE_MKL)
222 #define REGISTER(T)                                   \
223   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
224                               .Device(DEVICE_CPU)     \
225                               .TypeConstraint<T>("T") \
226                               .HostMemory("perm"),    \
227                           MklTransposeCpuOp);         \
228   REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
229                               .Device(DEVICE_CPU)     \
230                               .TypeConstraint<T>("T") \
231                               .HostMemory("perm"),    \
232                           MklConjugateTransposeCpuOp);
233 
234 #else  // INTEL_MKL && ENABLE_MKL
235 #define REGISTER(T)                                   \
236   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
237                               .Device(DEVICE_CPU)     \
238                               .TypeConstraint<T>("T") \
239                               .HostMemory("perm"),    \
240                           TransposeCpuOp);            \
241   REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
242                               .Device(DEVICE_CPU)     \
243                               .TypeConstraint<T>("T") \
244                               .HostMemory("perm"),    \
245                           ConjugateTransposeCpuOp);
246 #endif  // INTEL_MKL && ENABLE_MKL
247 
TF_CALL_ALL_TYPES(REGISTER)248 TF_CALL_ALL_TYPES(REGISTER)
249 #undef REGISTER
250 
251 #if GOOGLE_CUDA
252 Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
253                                    gtl::ArraySlice<int32> perm, Tensor* out) {
254   typedef Eigen::GpuDevice GPUDevice;
255   return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
256                                    out);
257 }
DoTranspose(OpKernelContext * ctx,const Tensor & in,gtl::ArraySlice<int32> perm,Tensor * out)258 Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
259                                             const Tensor& in,
260                                             gtl::ArraySlice<int32> perm,
261                                             Tensor* out) {
262   typedef Eigen::GpuDevice GPUDevice;
263   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<GPUDevice>(), in,
264                                             perm, out);
265 }
266 
267 #define REGISTER(T)                                   \
268   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
269                               .Device(DEVICE_GPU)     \
270                               .TypeConstraint<T>("T") \
271                               .HostMemory("perm"),    \
272                           TransposeGpuOp);            \
273   REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
274                               .Device(DEVICE_GPU)     \
275                               .TypeConstraint<T>("T") \
276                               .HostMemory("perm"),    \
277                           ConjugateTransposeGpuOp);
278 TF_CALL_POD_TYPES(REGISTER);
279 #undef REGISTER
280 #endif
281 
282 #ifdef TENSORFLOW_USE_SYCL
DoTranspose(OpKernelContext * ctx,const Tensor & in,gtl::ArraySlice<int32> perm,Tensor * out)283 Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
284                                     gtl::ArraySlice<int32> perm, Tensor* out) {
285   typedef Eigen::SyclDevice SYCLDevice;
286   return ::tensorflow::DoTranspose(ctx->eigen_device<SYCLDevice>(), in, perm,
287                                    out);
288 }
DoTranspose(OpKernelContext * ctx,const Tensor & in,gtl::ArraySlice<int32> perm,Tensor * out)289 Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
290                                              const Tensor& in,
291                                              gtl::ArraySlice<int32> perm,
292                                              Tensor* out) {
293   typedef Eigen::SyclDevice SYCLDevice;
294   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
295                                             perm, out);
296 }
297 #define REGISTER(T)                                   \
298   REGISTER_KERNEL_BUILDER(Name("Transpose")           \
299                               .Device(DEVICE_SYCL)    \
300                               .TypeConstraint<T>("T") \
301                               .HostMemory("perm"),    \
302                           TransposeSyclOp);           \
303   REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
304                               .Device(DEVICE_SYCL)    \
305                               .TypeConstraint<T>("T") \
306                               .HostMemory("perm"),    \
307                           ConjugateTransposeSyclOp);
308 TF_CALL_POD_TYPES(REGISTER);
309 #undef REGISTER
310 #endif
311 }  // namespace tensorflow
312