1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include <complex>
19
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/kernels/transpose_functor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28
29 typedef Eigen::ThreadPoolDevice CPUDevice;
30
31 namespace tensorflow {
32 namespace {
33
34 template <typename T, bool conjugate>
TransposeSimple(const CPUDevice & device,const Tensor & in,const gtl::ArraySlice<int32> perm,Tensor * out)35 void TransposeSimple(const CPUDevice& device, const Tensor& in,
36 const gtl::ArraySlice<int32> perm, Tensor* out) {
37 const int ndims = in.dims();
38 gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
39 gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
40 const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
41 T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
42 auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
43 int64 end) {
44 for (int64 o_idx = begin; o_idx < end; ++o_idx) {
45 int64 i_idx = 0;
46 int64 t = o_idx;
47 for (int i = 0; i < ndims; ++i) {
48 const int64 ratio = t / out_strides[i];
49 t -= ratio * out_strides[i];
50 i_idx += ratio * in_strides[perm[i]];
51 }
52 if (conjugate) {
53 q[o_idx] = Eigen::numext::conj(p[i_idx]);
54 } else {
55 q[o_idx] = p[i_idx];
56 }
57 }
58 };
59 double cycles_per_element =
60 (conjugate ? 1 : 0) + ndims * (Eigen::TensorOpCost::DivCost<int64>() +
61 2 * Eigen::TensorOpCost::MulCost<int64>() +
62 2 * Eigen::TensorOpCost::AddCost<int64>());
63 Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T),
64 /*bytes_stored=*/sizeof(T), cycles_per_element);
65 device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
66 }
67
68 } // namespace
69
70 template <typename T, bool conjugate>
71 struct Transpose<CPUDevice, T, conjugate> {
runtensorflow::Transpose72 static void run(const CPUDevice& d, const Tensor& in,
73 const gtl::ArraySlice<int32> perm, Tensor* out) {
74 switch (in.dims()) {
75 case 2:
76 internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
77 out);
78 break;
79 case 3:
80 internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
81 out);
82 break;
83 case 4:
84 internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
85 out);
86 break;
87 case 5:
88 internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
89 out);
90 break;
91 case 6:
92 internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
93 out);
94 break;
95 case 7:
96 internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
97 out);
98 break;
99 case 8:
100 internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
101 out);
102 break;
103 default:
104 TransposeSimple<T, conjugate>(d, in, perm, out);
105 break;
106 }
107 }
108 };
109
110 #define INSTANTIATE(DEVICE) \
111 template <> \
112 Status DoTranspose(const DEVICE& device, const Tensor& in, \
113 const gtl::ArraySlice<int32> perm, Tensor* out) { \
114 return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
115 out); \
116 } \
117 template <> \
118 Status DoConjugateTranspose(const DEVICE& device, const Tensor& in, \
119 const gtl::ArraySlice<int32> perm, \
120 Tensor* out) { \
121 return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, \
122 out); \
123 } \
124 template <> \
125 Status DoMatrixTranspose(const DEVICE& device, const Tensor& in, \
126 Tensor* out) { \
127 return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
128 out); \
129 } \
130 template <> \
131 Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
132 Tensor* out) { \
133 return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, \
134 out); \
135 }
136
137 INSTANTIATE(CPUDevice)
138
#ifdef TENSORFLOW_USE_SYCL
using SYCLDevice = Eigen::SyclDevice;

namespace internal {

// Transposes `in` into `out` on a SYCL device.
//
// Dispatches on the runtime rank of `in` to internal::TransposeUsingEigen,
// with the rank fixed as a template argument. Ranks 1 through 8 are
// supported; any other rank is reported with LOG(FATAL). `conjugate` is
// forwarded unchanged.
//
// T does not appear in the parameter list, so callers must name it
// explicitly: TransposeSYCL<T>(...).
template <typename T>
void TransposeSYCL(const SYCLDevice& d, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, bool conjugate,
                   Tensor* out) {
  switch (in.dims()) {
    case 1:
      TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, conjugate, out);
      break;
    case 2:
      TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, conjugate, out);
      break;
    case 3:
      TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, conjugate, out);
      break;
    case 4:
      TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, conjugate, out);
      break;
    case 5:
      TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, conjugate, out);
      break;
    case 6:
      TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, conjugate, out);
      break;
    case 7:
      TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, conjugate, out);
      break;
    case 8:
      TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, conjugate, out);
      break;
    default:
      LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
      break;
  }
}

}  // namespace internal

// SYCL specialization of the Transpose functor for non-string element types.
// The string case has its own specialization in this section.
template <typename T, bool conjugate>
struct Transpose<SYCLDevice, T, conjugate> {
  static void run(const SYCLDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    // Two requirements for this call:
    //   - the helper is spelled TransposeSYCL;
    //   - its element type cannot be deduced from the arguments, so it is
    //     given explicitly.
    internal::TransposeSYCL<T>(d, in, perm, conjugate, out);
  }
};

// Strings cannot be transposed on a SYCL device; any attempt is fatal.
template <bool conjugate>
struct Transpose<SYCLDevice, string, conjugate> {
  static void run(const SYCLDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    LOG(FATAL) << "DT_STRING not supported on SYCL device.";
  }
};

// Explicit instantiation.
template struct Transpose<SYCLDevice, string, false>;

INSTANTIATE(SYCLDevice)

#endif  // TENSORFLOW_USE_SYCL

// INSTANTIATE is a file-local helper macro. Undefine it in every build
// configuration, not only when SYCL support is compiled in.
#undef INSTANTIATE
203
204 } // namespace tensorflow
205