• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <complex>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/kernels/transpose_functor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 
29 typedef Eigen::ThreadPoolDevice CPUDevice;
30 
31 namespace tensorflow {
32 namespace {
33 
34 template <typename T, bool conjugate>
TransposeSimple(const CPUDevice & device,const Tensor & in,const gtl::ArraySlice<int32> perm,Tensor * out)35 void TransposeSimple(const CPUDevice& device, const Tensor& in,
36                      const gtl::ArraySlice<int32> perm, Tensor* out) {
37   const int ndims = in.dims();
38   gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
39   gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
40   const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
41   T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
42   auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
43                                                             int64 end) {
44     for (int64 o_idx = begin; o_idx < end; ++o_idx) {
45       int64 i_idx = 0;
46       int64 t = o_idx;
47       for (int i = 0; i < ndims; ++i) {
48         const int64 ratio = t / out_strides[i];
49         t -= ratio * out_strides[i];
50         i_idx += ratio * in_strides[perm[i]];
51       }
52       if (conjugate) {
53         q[o_idx] = Eigen::numext::conj(p[i_idx]);
54       } else {
55         q[o_idx] = p[i_idx];
56       }
57     }
58   };
59   double cycles_per_element =
60       (conjugate ? 1 : 0) + ndims * (Eigen::TensorOpCost::DivCost<int64>() +
61                                      2 * Eigen::TensorOpCost::MulCost<int64>() +
62                                      2 * Eigen::TensorOpCost::AddCost<int64>());
63   Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T),
64                            /*bytes_stored=*/sizeof(T), cycles_per_element);
65   device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
66 }
67 
68 }  // namespace
69 
70 template <typename T, bool conjugate>
71 struct Transpose<CPUDevice, T, conjugate> {
runtensorflow::Transpose72   static void run(const CPUDevice& d, const Tensor& in,
73                   const gtl::ArraySlice<int32> perm, Tensor* out) {
74     switch (in.dims()) {
75       case 2:
76         internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
77                                                        out);
78         break;
79       case 3:
80         internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
81                                                        out);
82         break;
83       case 4:
84         internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
85                                                        out);
86         break;
87       case 5:
88         internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
89                                                        out);
90         break;
91       case 6:
92         internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
93                                                        out);
94         break;
95       case 7:
96         internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
97                                                        out);
98         break;
99       case 8:
100         internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
101                                                        out);
102         break;
103       default:
104         TransposeSimple<T, conjugate>(d, in, perm, out);
105         break;
106     }
107   }
108 };
109 
110 #define INSTANTIATE(DEVICE)                                                 \
111   template <>                                                               \
112   Status DoTranspose(const DEVICE& device, const Tensor& in,                \
113                      const gtl::ArraySlice<int32> perm, Tensor* out) {      \
114     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
115                                      out);                                  \
116   }                                                                         \
117   template <>                                                               \
118   Status DoConjugateTranspose(const DEVICE& device, const Tensor& in,       \
119                               const gtl::ArraySlice<int32> perm,            \
120                               Tensor* out) {                                \
121     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true,  \
122                                      out);                                  \
123   }                                                                         \
124   template <>                                                               \
125   Status DoMatrixTranspose(const DEVICE& device, const Tensor& in,          \
126                            Tensor* out) {                                   \
127     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
128                                            out);                            \
129   }                                                                         \
130   template <>                                                               \
131   Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
132                                     Tensor* out) {                          \
133     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true,  \
134                                            out);                            \
135   }
136 
137 INSTANTIATE(CPUDevice)
138 
139 
140 }  // namespace tensorflow
141