1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
17 #define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
18
19 #include <numeric>
20 #include <string>
21 #include <vector>
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/platform/logging.h"
25
26 namespace tensorflow {
27 // Transpose tensor 'in' into tensor 'out' according to dimension
28 // permutation 'perm'.
29 //
30 // REQUIRES: in.dtype() == out->dtype()
31 // REQUIRES: in.dims() == out->dims()
32 // REQUIRES: in.dims() == perm.size()
33 // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
34 template <typename Device>
35 Status DoTranspose(const Device& device, const Tensor& in,
36 const gtl::ArraySlice<int32> perm, Tensor* out);
37
38 // Conjugate and transpose tensor 'in' into tensor 'out' according to dimension
39 // permutation 'perm'.
40 //
41 // REQUIRES: in.dtype() == out->dtype()
42 // REQUIRES: in.dims() == out->dims()
43 // REQUIRES: in.dims() == perm.size()
44 // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
45 template <typename Device>
46 Status DoConjugateTranspose(const Device& device, const Tensor& in,
47 const gtl::ArraySlice<int32> perm, Tensor* out);
48
49 // Convenience versions of DoTranspose that only swap the last (inner) two
50 // dimensions.
51 template <typename Device>
52 Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);
53
54 // Convenience versions of DoConjugateTranspose that only swap the last (inner)
55 // two dimensions.
56 template <typename Device>
57 Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
58 Tensor* out);
59
60 // Primary device specific functor to be specialized for each device and type.
61 template <typename Device, typename T, bool conjugate = false>
62 struct Transpose {
63 static void run(const Device& d, const Tensor& in,
64 const gtl::ArraySlice<int32> perm, Tensor* out);
65 };
66
67 // Implementation details.
68 namespace internal {
69
70 typedef gtl::InlinedVector<int64, 8> TransposeDimsVec;
71 typedef gtl::InlinedVector<int32, 8> TransposePermsVec;
72
73 // Helper function that takes a tensor shape, a permutation, combines the
74 // neighboring shapes if their indices in the permutation are consecutive.
75 // The function outputs the combined shape and new permutation.
76 // Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will
77 // produce new shape {2, 60, 120} and new permutation {0, 2, 1}.
ReduceTransposeDimensions(const TensorShape & shape,gtl::ArraySlice<int32> perm,TransposePermsVec * new_perm,TransposeDimsVec * new_dims)78 inline void ReduceTransposeDimensions(const TensorShape& shape,
79 gtl::ArraySlice<int32> perm,
80 TransposePermsVec* new_perm,
81 TransposeDimsVec* new_dims) {
82 CHECK_EQ(shape.dims(), perm.size());
83 if (shape.dims() == 1) {
84 // If input dimension is already 1, no need to reduce dimension.
85 new_perm->resize(1);
86 (*new_perm)[0] = perm[0];
87 (*new_dims)[0] = shape.dim_size(0);
88 return;
89 }
90 TransposePermsVec new_dim_position(shape.dims(), -1);
91 TransposeDimsVec combined_dims(shape.dims(), 0);
92 int cur_head = perm[0];
93 new_dim_position[cur_head] = 0;
94 combined_dims[0] = shape.dim_size(cur_head);
95 int dim_idx = 0;
96 for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) {
97 // If two indices in permutation are consecutive numbers, combine their
98 // dimensions.
99 if (cur_head + 1 == perm[perm_idx]) {
100 cur_head = perm[perm_idx];
101 combined_dims[dim_idx] *= shape.dim_size(cur_head);
102 } else {
103 // Else start a new dimension.
104 cur_head = perm[perm_idx];
105 dim_idx++;
106 new_dim_position[cur_head] = dim_idx;
107 combined_dims[dim_idx] = shape.dim_size(cur_head);
108 }
109 }
110 // Compact the new permutations and dimension sizes.
111 new_perm->resize(dim_idx + 1);
112 new_dims->resize(dim_idx + 1);
113 dim_idx = 0;
114 for (int i = 0; i < new_dim_position.size(); ++i) {
115 if (new_dim_position[i] >= 0) {
116 int new_perm_idx = new_dim_position[i];
117 (*new_perm)[dim_idx] = new_perm_idx;
118 (*new_dims)[dim_idx] = combined_dims[new_perm_idx];
119 dim_idx++;
120 }
121 }
122 }
123
124 // If all non-singleton dimensions remain in ascending order, the shuffled
125 // singletons can be transposed by a reshape, saving a memory allocation & copy.
126 // |permutation| must be a permutation of {0, .., input_shape.dims() - 1}.
127 // That is, for all i, 0 <= perm[i] < input_shape.dims().
128 // In practice, this is checked in TransposeOp::Compute prior to calling this
129 // function, and the function sits here to facilitate unit testing.
NonSingletonDimensionsAlign(const TensorShape & input_shape,const std::vector<int32> & permutation)130 inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
131 const std::vector<int32>& permutation) {
132 int last_nonsingleton_perm_dim = -1;
133 for (int perm_dim : permutation) {
134 if (input_shape.dim_size(perm_dim) == 1) {
135 continue;
136 }
137 if (perm_dim < last_nonsingleton_perm_dim) {
138 return false;
139 }
140 last_nonsingleton_perm_dim = perm_dim;
141 }
142 return true;
143 }
144
145 // Uses Eigen to transpose.
146 template <typename Device, typename T, int NDIMS>
TransposeUsingEigen(const Device & d,const Tensor & in,const gtl::ArraySlice<int32> perm,bool conjugate,Tensor * out)147 void TransposeUsingEigen(const Device& d, const Tensor& in,
148 const gtl::ArraySlice<int32> perm, bool conjugate,
149 Tensor* out) {
150 Eigen::array<int, NDIMS> p;
151 for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
152 auto x = typename TTypes<T, NDIMS>::ConstTensor(
153 reinterpret_cast<const T*>(in.tensor_data().data()),
154 in.shape().AsEigenDSizes<NDIMS>());
155 auto y = typename TTypes<T, NDIMS>::Tensor(
156 reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
157 out->shape().AsEigenDSizes<NDIMS>());
158 if (conjugate) {
159 y.device(d) = x.conjugate().shuffle(p);
160 } else {
161 y.device(d) = x.shuffle(p);
162 }
163 }
164
165 template <typename Device>
DoTransposeImpl(const Device & d,const Tensor & in,const gtl::ArraySlice<int32> perm,bool conjugate,Tensor * out)166 Status DoTransposeImpl(const Device& d, const Tensor& in,
167 const gtl::ArraySlice<int32> perm, bool conjugate,
168 Tensor* out) {
169 CHECK_GE(in.dims(), 2);
170 CHECK_EQ(in.dims(), out->dims());
171 CHECK_EQ(in.dims(), perm.size());
172 CHECK_EQ(in.dtype(), out->dtype());
173 switch (in.dtype()) {
174 case DT_BOOL:
175 case DT_INT8:
176 case DT_QINT8:
177 case DT_QUINT8:
178 case DT_UINT8:
179 Transpose<Device, uint8>::run(d, in, perm, out);
180 break;
181
182 case DT_BFLOAT16:
183 case DT_HALF:
184 case DT_INT16:
185 case DT_QINT16:
186 case DT_QUINT16:
187 case DT_UINT16:
188 Transpose<Device, uint16>::run(d, in, perm, out);
189 break;
190
191 case DT_FLOAT:
192 case DT_INT32:
193 case DT_QINT32:
194 Transpose<Device, uint32>::run(d, in, perm, out);
195 break;
196
197 case DT_DOUBLE:
198 case DT_INT64:
199 Transpose<Device, uint64>::run(d, in, perm, out);
200 break;
201
202 case DT_COMPLEX64:
203 if (conjugate) {
204 #if defined(__ANDROID__) and !defined(__clang__)
205 // Workaround for GCC compiler bug in Android toolchain.
206 return errors::Unimplemented(
207 "Conjugate transpose of complex64 not supported for GCC on "
208 "Android.");
209 #else
210 Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
211 #endif
212 } else {
213 Transpose<Device, uint64>::run(d, in, perm, out);
214 }
215 break;
216
217 case DT_COMPLEX128:
218 if (conjugate) {
219 Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
220 out);
221 } else {
222 Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
223 out);
224 }
225 break;
226
227 case DT_STRING:
228 Transpose<Device, string>::run(d, in, perm, out);
229 break;
230
231 default:
232 return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
233 }
234 return Status::OK();
235 }
236
237 template <typename Device>
DoMatrixTransposeImpl(const Device & device,const Tensor & in,bool conjugate,Tensor * out)238 inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
239 bool conjugate, Tensor* out) {
240 const int ndims = in.dims();
241 if (ndims == 0) return Status::OK();
242 TransposePermsVec perm(ndims);
243 std::iota(perm.begin(), perm.end(), 0);
244 std::swap(perm[ndims - 2], perm[ndims - 1]);
245 return DoTransposeImpl(device, in, perm, conjugate, out);
246 }
247
#ifdef TENSORFLOW_USE_SYCL
// On SYCL devices every transpose is routed through Eigen. Only the
// declaration lives here; the definition is provided elsewhere.
template <typename Device, typename T>
void TransposeSYCL(const Device& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, bool conjugate,
                   Tensor* out);
#endif  // TENSORFLOW_USE_SYCL
255
256 } // namespace internal
257 } // namespace tensorflow
258
259 #endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
260