• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <complex>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/kernels/transpose_functor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 
29 typedef Eigen::ThreadPoolDevice CPUDevice;
30 
31 namespace tensorflow {
32 namespace {
33 
34 template <typename T, bool conjugate>
TransposeSimple(const CPUDevice & device,const Tensor & in,const gtl::ArraySlice<int32> perm,Tensor * out)35 void TransposeSimple(const CPUDevice& device, const Tensor& in,
36                      const gtl::ArraySlice<int32> perm, Tensor* out) {
37   const int ndims = in.dims();
38   gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
39   gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
40   const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
41   T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
42   auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
43                                                             int64 end) {
44     for (int64 o_idx = begin; o_idx < end; ++o_idx) {
45       int64 i_idx = 0;
46       int64 t = o_idx;
47       for (int i = 0; i < ndims; ++i) {
48         const int64 ratio = t / out_strides[i];
49         t -= ratio * out_strides[i];
50         i_idx += ratio * in_strides[perm[i]];
51       }
52       if (conjugate) {
53         q[o_idx] = Eigen::numext::conj(p[i_idx]);
54       } else {
55         q[o_idx] = p[i_idx];
56       }
57     }
58   };
59   double cycles_per_element =
60       (conjugate ? 1 : 0) + ndims * (Eigen::TensorOpCost::DivCost<int64>() +
61                                      2 * Eigen::TensorOpCost::MulCost<int64>() +
62                                      2 * Eigen::TensorOpCost::AddCost<int64>());
63   Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T),
64                            /*bytes_stored=*/sizeof(T), cycles_per_element);
65   device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
66 }
67 
68 }  // namespace
69 
70 template <typename T, bool conjugate>
71 struct Transpose<CPUDevice, T, conjugate> {
runtensorflow::Transpose72   static void run(const CPUDevice& d, const Tensor& in,
73                   const gtl::ArraySlice<int32> perm, Tensor* out) {
74     switch (in.dims()) {
75       case 2:
76         internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
77                                                        out);
78         break;
79       case 3:
80         internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
81                                                        out);
82         break;
83       case 4:
84         internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
85                                                        out);
86         break;
87       case 5:
88         internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
89                                                        out);
90         break;
91       case 6:
92         internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
93                                                        out);
94         break;
95       case 7:
96         internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
97                                                        out);
98         break;
99       case 8:
100         internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
101                                                        out);
102         break;
103       default:
104         TransposeSimple<T, conjugate>(d, in, perm, out);
105         break;
106     }
107   }
108 };
109 
110 #define INSTANTIATE(DEVICE)                                                 \
111   template <>                                                               \
112   Status DoTranspose(const DEVICE& device, const Tensor& in,                \
113                      const gtl::ArraySlice<int32> perm, Tensor* out) {      \
114     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
115                                      out);                                  \
116   }                                                                         \
117   template <>                                                               \
118   Status DoConjugateTranspose(const DEVICE& device, const Tensor& in,       \
119                               const gtl::ArraySlice<int32> perm,            \
120                               Tensor* out) {                                \
121     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true,  \
122                                      out);                                  \
123   }                                                                         \
124   template <>                                                               \
125   Status DoMatrixTranspose(const DEVICE& device, const Tensor& in,          \
126                            Tensor* out) {                                   \
127     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
128                                            out);                            \
129   }                                                                         \
130   template <>                                                               \
131   Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
132                                     Tensor* out) {                          \
133     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true,  \
134                                            out);                            \
135   }
136 
137 INSTANTIATE(CPUDevice)
138 
139 #ifdef TENSORFLOW_USE_SYCL
140 typedef Eigen::SyclDevice SYCLDevice;
141 
142 namespace internal {
143 template <typename T>
TransposeSYCL(const SYCLDevice & d,const Tensor & in,const gtl::ArraySlice<int32> perm,bool conjugate,Tensor * out)144 void TransposeSYCL(const SYCLDevice& d, const Tensor& in,
145                    const gtl::ArraySlice<int32> perm, bool conjugate,
146                    Tensor* out) {
147   switch (in.dims()) {
148     case 1:
149       TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, conjugate, out);
150       break;
151     case 2:
152       TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, conjugate, out);
153       break;
154     case 3:
155       TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, conjugate, out);
156       break;
157     case 4:
158       TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, conjugate, out);
159       break;
160     case 5:
161       TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, conjugate, out);
162       break;
163     case 6:
164       TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, conjugate, out);
165       break;
166     case 7:
167       TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, conjugate, out);
168       break;
169     case 8:
170       TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, conjugate, out);
171       break;
172     default:
173       LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
174       break;
175   }
176 }
177 
178 }  // namespace internal
179 
180 template <typename T, bool conjugate>
181 struct Transpose<SYCLDevice, T, conjugate> {
runtensorflow::Transpose182   static void run(const SYCLDevice& d, const Tensor& in,
183                   const gtl::ArraySlice<int32> perm, Tensor* out) {
184     internal::TransposeSycl(d, in, perm, conjugate, out);
185   }
186 };
187 
188 template <bool conjugate>
189 struct Transpose<SYCLDevice, string, conjugate> {
runtensorflow::Transpose190   static void run(const SYCLDevice& d, const Tensor& in,
191                   const gtl::ArraySlice<int32> perm, Tensor* out) {
192     LOG(FATAL) << "DT_STRING not supported on SYCL device.";
193   }
194 };
195 
196 // Explicit instantiation.
197 template struct Transpose<SYCLDevice, string, false>;
198 
199 INSTANTIATE(SYCLDevice)
200 #undef INSTANTIATE
201 
202 #endif  // TENSORFLOW_USE_SYCL
203 
204 }  // namespace tensorflow
205