1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include <complex>
19
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/kernels/transpose_functor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28
29 typedef Eigen::ThreadPoolDevice CPUDevice;
30
31 namespace tensorflow {
32 namespace {
33
34 template <typename T, bool conjugate>
TransposeSimple(const CPUDevice & device,const Tensor & in,const gtl::ArraySlice<int32> perm,Tensor * out)35 void TransposeSimple(const CPUDevice& device, const Tensor& in,
36 const gtl::ArraySlice<int32> perm, Tensor* out) {
37 const int ndims = in.dims();
38 gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
39 gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
40 const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
41 T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
42 auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
43 int64 end) {
44 for (int64 o_idx = begin; o_idx < end; ++o_idx) {
45 int64 i_idx = 0;
46 int64 t = o_idx;
47 for (int i = 0; i < ndims; ++i) {
48 const int64 ratio = t / out_strides[i];
49 t -= ratio * out_strides[i];
50 i_idx += ratio * in_strides[perm[i]];
51 }
52 if (conjugate) {
53 q[o_idx] = Eigen::numext::conj(p[i_idx]);
54 } else {
55 q[o_idx] = p[i_idx];
56 }
57 }
58 };
59 double cycles_per_element =
60 (conjugate ? 1 : 0) + ndims * (Eigen::TensorOpCost::DivCost<int64>() +
61 2 * Eigen::TensorOpCost::MulCost<int64>() +
62 2 * Eigen::TensorOpCost::AddCost<int64>());
63 Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T),
64 /*bytes_stored=*/sizeof(T), cycles_per_element);
65 device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
66 }
67
68 } // namespace
69
70 template <typename T, bool conjugate>
71 struct Transpose<CPUDevice, T, conjugate> {
runtensorflow::Transpose72 static void run(const CPUDevice& d, const Tensor& in,
73 const gtl::ArraySlice<int32> perm, Tensor* out) {
74 switch (in.dims()) {
75 case 2:
76 internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
77 out);
78 break;
79 case 3:
80 internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
81 out);
82 break;
83 case 4:
84 internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
85 out);
86 break;
87 case 5:
88 internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
89 out);
90 break;
91 case 6:
92 internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
93 out);
94 break;
95 case 7:
96 internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
97 out);
98 break;
99 case 8:
100 internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
101 out);
102 break;
103 default:
104 TransposeSimple<T, conjugate>(d, in, perm, out);
105 break;
106 }
107 }
108 };
109
// Explicitly specializes the four public transpose entry points
// (DoTranspose, DoConjugateTranspose, DoMatrixTranspose and
// DoConjugateMatrixTranspose) for the given device type, each forwarding
// to the shared internal:: implementation with the appropriate
// compile-time `conjugate` flag.
#define INSTANTIATE(DEVICE)                                                 \
  template <>                                                               \
  Status DoTranspose(const DEVICE& device, const Tensor& in,                \
                     const gtl::ArraySlice<int32> perm, Tensor* out) {      \
    return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
                                     out);                                  \
  }                                                                         \
  template <>                                                               \
  Status DoConjugateTranspose(const DEVICE& device, const Tensor& in,       \
                              const gtl::ArraySlice<int32> perm,            \
                              Tensor* out) {                                \
    return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true,  \
                                     out);                                  \
  }                                                                         \
  template <>                                                               \
  Status DoMatrixTranspose(const DEVICE& device, const Tensor& in,          \
                           Tensor* out) {                                   \
    return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
                                           out);                            \
  }                                                                         \
  template <>                                                               \
  Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
                                    Tensor* out) {                          \
    return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true,  \
                                           out);                            \
  }

// Instantiate the entry points for the CPU (Eigen thread-pool) device.
INSTANTIATE(CPUDevice)
138
139
140 } // namespace tensorflow
141