• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/transpose.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/inlined_vector.h"
25 #include "absl/numeric/int128.h"
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/compiler/xla/array.h"
28 #include "tensorflow/compiler/xla/permutation_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33 #include "tensorflow/core/platform/threadpool.h"
34 #include "tensorflow/core/protobuf/error_codes.pb.h"
35 
36 namespace xla {
37 
38 class TestTransposePlan : public TransposePlan {
39  public:
40   using TransposePlan::CoalesceDimensions;
41   using TransposePlan::RemoveTrivialDimensions;
42 };
43 
// RemoveTrivialDimensions should drop size-1 dimensions from the shape and the
// permutation (renumbering the surviving permutation entries), and should leave
// a shape without size-1 dimensions untouched.
TEST(TransposeTest, RemoveTrivialDimensions) {
  using DimVector = absl::InlinedVector<int64_t, 4>;
  DimVector lda_tile;
  DimVector input_tiling;
  DimVector output_tiling;
  // Fills the three tile-related vectors with `rank` ones.
  auto reset_unit_tilings = [&](size_t rank) {
    lda_tile.assign(rank, 1);
    input_tiling.assign(rank, 1);
    output_tiling.assign(rank, 1);
  };

  // Dimensions 2 and 4 have size 1 and must disappear.
  DimVector dims = {4, 5, 1, 3, 1, 2, 5};
  DimVector perm = {0, 2, 1, 4, 3, 6, 5};
  DimVector lda = {2, 5, 7, 100, 3, 0, 1};
  reset_unit_tilings(7);
  TestTransposePlan::RemoveTrivialDimensions(dims, perm, lda, lda_tile,
                                             input_tiling, output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(4, 5, 3, 2, 5));
  EXPECT_THAT(perm, testing::ElementsAre(0, 1, 2, 4, 3));

  // No size-1 dimensions: nothing changes.
  dims = {4, 5, 3, 2, 5};
  perm = {4, 3, 2, 1, 0};
  lda = {2, 5, 100, 0, 1};
  reset_unit_tilings(5);
  TestTransposePlan::RemoveTrivialDimensions(dims, perm, lda, lda_tile,
                                             input_tiling, output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(4, 5, 3, 2, 5));
  EXPECT_THAT(perm, testing::ElementsAre(4, 3, 2, 1, 0));
}
67 
// CoalesceDimensions should merge runs of input dimensions that stay
// consecutive (in order) under the permutation and whose strides in `lda` are
// contiguous, shrinking dims/perm/lda accordingly.
TEST(TransposeTest, CoalesceDimensions) {
  using DimVector = absl::InlinedVector<int64_t, 4>;
  DimVector lda_tile;
  DimVector input_tiling;
  DimVector output_tiling;
  // Fills the three tile-related vectors with `rank` ones.
  auto reset_unit_tilings = [&](size_t rank) {
    lda_tile.assign(rank, 1);
    input_tiling.assign(rank, 1);
    output_tiling.assign(rank, 1);
  };

  // No two input dimensions remain adjacent under `perm`: nothing merges.
  DimVector dims = {4, 5, 1, 3, 1, 2, 5};
  DimVector perm = {0, 2, 1, 4, 3, 6, 5};
  DimVector lda = {50, 30, 30, 10, 10, 5, 1};
  reset_unit_tilings(7);
  TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile, input_tiling,
                                        output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(4, 5, 1, 3, 1, 2, 5));
  EXPECT_THAT(perm, testing::ElementsAre(0, 2, 1, 4, 3, 6, 5));
  EXPECT_THAT(lda, testing::ElementsAre(50, 30, 30, 10, 10, 5, 1));

  // Input dimensions 1..3 stay consecutive and contiguous: they fuse into a
  // single dimension of size 5 * 3 * 2 = 30.
  dims = {4, 5, 3, 2, 5};
  perm = {4, 1, 2, 3, 0};
  lda = {150, 30, 10, 5, 1};
  reset_unit_tilings(5);
  TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile, input_tiling,
                                        output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(4, 30, 5));
  EXPECT_THAT(perm, testing::ElementsAre(2, 1, 0));
  EXPECT_THAT(lda, testing::ElementsAre(150, 5, 1));

  // Identity permutation over a contiguous array collapses to one dimension.
  dims = {4, 5, 3, 2, 5};
  perm = {0, 1, 2, 3, 4};
  lda = {150, 30, 10, 5, 1};
  reset_unit_tilings(5);
  TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile, input_tiling,
                                        output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(600));
  EXPECT_THAT(perm, testing::ElementsAre(0));
  EXPECT_THAT(lda, testing::ElementsAre(1));

  // Non-standard stridings prevent coalescing: the stride 7 of dimension 3 is
  // not contiguous with dimension 2, so only dimensions 1 and 2 fuse.
  dims = {4, 5, 3, 2, 5};
  perm = {4, 1, 2, 3, 0};
  lda = {150, 30, 10, 7, 1};
  reset_unit_tilings(5);
  TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile, input_tiling,
                                        output_tiling);
  EXPECT_THAT(dims, testing::ElementsAre(4, 15, 2, 5));
  EXPECT_THAT(perm, testing::ElementsAre(3, 1, 2, 0));
  EXPECT_THAT(lda, testing::ElementsAre(150, 10, 7, 1));
}
117 
// TransposePlan::Create must reject a plan in which both the input and the
// output carry a non-trivial tiling.
TEST(TransposeTest, InvalidTilings) {
  auto plan =
      TransposePlan::Create(sizeof(float), {3, 4, 5}, {0, 1, 2},
                            /*input_layout=*/TransposePlan::Tiling{{8, 128}},
                            /*output_tiling=*/TransposePlan::Tiling{{4}});
  const auto& status = plan.status();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
  EXPECT_THAT(
      status.error_message(),
      testing::HasSubstr(
          "Only one of the input and output may have a non-trivial tiling"));
}
129 
130 // Computes the size in elements of a tiled array.
SizeOfTiledArray(absl::Span<int64_t const> shape,absl::Span<int64_t const> tiling)131 int64_t SizeOfTiledArray(absl::Span<int64_t const> shape,
132                          absl::Span<int64_t const> tiling) {
133   int64_t size = 1;
134   for (size_t i = 0; i < shape.size(); ++i) {
135     if (i >= shape.size() - tiling.size()) {
136       size *= RoundUpTo(shape[i], tiling[i - (shape.size() - tiling.size())]);
137     } else {
138       size *= shape[i];
139     }
140   }
141   return size;
142 }
143 
144 // Advances 'indices' in the lexicographical order of the multidimensional
145 // array with `shape`. Returns false if the end of the array has been reached.
BumpIndices(absl::Span<int64_t const> shape,absl::Span<int64_t> indices)146 bool BumpIndices(absl::Span<int64_t const> shape, absl::Span<int64_t> indices) {
147   CHECK_EQ(shape.size(), indices.size());
148   for (int dimno = indices.size() - 1; dimno >= 0; --dimno) {
149     if (indices[dimno] + 1 < shape[dimno]) {
150       indices[dimno]++;
151       // Whenever an index of a dimension is increased, it means that all
152       // following dimensions have maxed out, so they must go to 0.
153       std::fill(indices.begin() + dimno + 1, indices.end(), 0);
154       return true;
155     }
156   }
157   return false;
158 }
159 
160 // Converts a multidimensional index `indices` into an array with `shape` and
161 // tiling `tiling` into a linear offset into a buffer.
IndexToLinearIndex(absl::Span<int64_t const> shape,absl::Span<int64_t const> tiling,absl::Span<int64_t const> indices)162 int64_t IndexToLinearIndex(absl::Span<int64_t const> shape,
163                            absl::Span<int64_t const> tiling,
164                            absl::Span<int64_t const> indices) {
165   CHECK_LE(tiling.size(), shape.size());
166   CHECK_EQ(shape.size(), indices.size());
167   int64_t stride = 1;
168   int64_t offset = 0;
169 
170   auto index_it = indices.rbegin();
171   auto tile_it = tiling.rbegin();
172   for (; tile_it != tiling.rend(); ++index_it, ++tile_it) {
173     offset += (*index_it % *tile_it) * stride;
174     stride *= *tile_it;
175   }
176   index_it = indices.rbegin();
177   tile_it = tiling.rbegin();
178   auto shape_it = shape.rbegin();
179   for (; tile_it != tiling.rend(); ++index_it, ++shape_it, ++tile_it) {
180     offset += (*index_it / *tile_it) * stride;
181     stride *= CeilOfRatio(*shape_it, *tile_it);
182   }
183   for (; shape_it != shape.rend(); ++index_it, ++shape_it) {
184     offset += *index_it * stride;
185     stride *= *shape_it;
186   }
187   return offset;
188 }
189 
190 // Slow reference code that converts an array from an untiled layout into a
191 // tiled layout.
192 template <typename T>
TileArray(const Array<T> & in,absl::Span<int64_t const> tiling)193 std::vector<T> TileArray(const Array<T>& in, absl::Span<int64_t const> tiling) {
194   std::vector<T> out(SizeOfTiledArray(in.dimensions(), tiling), -1);
195   if (in.num_elements() == 0) {
196     return out;
197   }
198   std::vector<int64_t> indices(in.num_dimensions(), 0);
199   do {
200     int64_t i = IndexToLinearIndex(in.dimensions(), tiling, indices);
201     out.at(i) = in(indices);
202   } while (BumpIndices(in.dimensions(), absl::MakeSpan(indices)));
203   return out;
204 }
205 
206 // Reference implementation: transpose using Eigen.
207 template <typename T, int NDIMS>
TransposeUsingEigenNd(const T * input,T * output,absl::Span<int64_t const> dims,absl::Span<int64_t const> dims_out,absl::Span<int64_t const> permutation)208 void TransposeUsingEigenNd(const T* input, T* output,
209                            absl::Span<int64_t const> dims,
210                            absl::Span<int64_t const> dims_out,
211                            absl::Span<int64_t const> permutation) {
212   typedef Eigen::TensorMap<
213       Eigen::Tensor<T, NDIMS, Eigen::RowMajor, Eigen::DenseIndex>,
214       Eigen::Aligned>
215       Tensor;
216   typedef Eigen::TensorMap<
217       Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, Eigen::DenseIndex>,
218       Eigen::Aligned>
219       ConstTensor;
220 
221   Eigen::array<int, NDIMS> p;
222   Eigen::DSizes<Eigen::DenseIndex, NDIMS> dims_eigen;
223   Eigen::DSizes<Eigen::DenseIndex, NDIMS> dims_out_eigen;
224   for (int i = 0; i < NDIMS; ++i) {
225     p[i] = permutation[i];
226     dims_eigen[i] = dims[i];
227     dims_out_eigen[i] = dims_out[i];
228   }
229   auto x = ConstTensor(input, dims_eigen);
230   auto y = Tensor(output, dims_out_eigen);
231   y = x.shuffle(p);
232 }
233 
234 template <typename T>
TransposeUsingEigen(const T * input,T * output,absl::Span<int64_t const> dims,absl::Span<int64_t const> dims_out,absl::Span<int64_t const> permutation)235 void TransposeUsingEigen(const T* input, T* output,
236                          absl::Span<int64_t const> dims,
237                          absl::Span<int64_t const> dims_out,
238                          absl::Span<int64_t const> permutation) {
239   switch (dims.size()) {
240     case 0:
241       return;
242     case 1:
243       TransposeUsingEigenNd<T, 1>(input, output, dims, dims_out, permutation);
244       return;
245     case 2:
246       TransposeUsingEigenNd<T, 2>(input, output, dims, dims_out, permutation);
247       return;
248     case 3:
249       TransposeUsingEigenNd<T, 3>(input, output, dims, dims_out, permutation);
250       return;
251     case 4:
252       TransposeUsingEigenNd<T, 4>(input, output, dims, dims_out, permutation);
253       return;
254     default:
255       LOG(FATAL) << "Unimplemented Eigen transpose rank";
256   }
257 }
258 
259 struct TransposeTestCase {
TransposeTestCasexla::TransposeTestCase260   TransposeTestCase(std::vector<int64_t> dims, std::vector<int64_t> permutation,
261                     std::vector<int64_t> input_tiling = {},
262                     std::vector<int64_t> output_tiling = {})
263       : dims(std::move(dims)),
264         permutation(std::move(permutation)),
265         input_tiling(std::move(input_tiling)),
266         output_tiling(std::move(output_tiling)) {}
267 
268   std::vector<int64_t> dims;
269   std::vector<int64_t> permutation;
270   std::vector<int64_t> input_tiling;
271   std::vector<int64_t> output_tiling;
272 
ToStringxla::TransposeTestCase273   std::string ToString() const {
274     return absl::StrFormat(
275         "[%s],perm=[%s],tiling=[%s]/[%s]", absl::StrJoin(dims, ","),
276         absl::StrJoin(permutation, ","), absl::StrJoin(input_tiling, ","),
277         absl::StrJoin(output_tiling, ","));
278   }
279 };
280 
operator <<(std::ostream & os,const TransposeTestCase & test)281 std::ostream& operator<<(std::ostream& os, const TransposeTestCase& test) {
282   os << test.ToString();
283   return os;
284 }
285 
GetTransposeTestCases()286 std::vector<TransposeTestCase> GetTransposeTestCases() {
287   std::vector<TransposeTestCase> cases = {
288       TransposeTestCase(/*dims=*/{1}, /*permutation=*/{0}),
289       TransposeTestCase(/*dims=*/{4}, /*permutation=*/{0}),
290       TransposeTestCase(/*dims=*/{27}, /*permutation=*/{0}),
291       TransposeTestCase(/*dims=*/{1, 1}, /*permutation=*/{0, 1}),
292       TransposeTestCase(/*dims=*/{1, 1}, /*permutation=*/{1, 0}),
293       TransposeTestCase(/*dims=*/{2, 2}, /*permutation=*/{0, 1}),
294       TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{1, 0}),
295       TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{0, 1}),
296       TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{1, 0}),
297       TransposeTestCase(/*dims=*/{8, 8}, /*permutation=*/{0, 1}),
298       TransposeTestCase(/*dims=*/{8, 8}, /*permutation=*/{1, 0}),
299       TransposeTestCase(/*dims=*/{16, 16}, /*permutation=*/{0, 1}),
300       TransposeTestCase(/*dims=*/{16, 16}, /*permutation=*/{1, 0}),
301       TransposeTestCase(/*dims=*/{11, 15}, /*permutation=*/{0, 1}),
302       TransposeTestCase(/*dims=*/{11, 15}, /*permutation=*/{1, 0}),
303       TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{0, 1, 2}),
304       TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{0, 2, 1}),
305       TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{1, 2, 0}),
306       TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{1, 0, 2}),
307       TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{2, 0, 1}),
308       TransposeTestCase(/*dims=*/{64, 64, 64}, /*permutation=*/{2, 1, 0}),
309       TransposeTestCase(/*dims=*/{256, 256, 256}, /*permutation=*/{2, 1, 0}),
310       TransposeTestCase(/*dims=*/{4, 8, 16, 32}, /*permutation=*/{3, 1, 0, 2}),
311       TransposeTestCase(/*dims=*/{64, 224, 224, 3},
312                         /*permutation=*/{3, 1, 2, 0}),
313 
314       TransposeTestCase(/*dims=*/{3}, /*permutation=*/{0},
315                         /*input_tiling=*/{3}),
316       TransposeTestCase(/*dims=*/{3}, /*permutation=*/{0},
317                         /*input_tiling=*/{},
318                         /*output_tiling=*/{3}),
319       TransposeTestCase(/*dims=*/{2, 4, 6}, /*permutation=*/{0, 1, 2},
320                         /*input_tiling=*/{},
321                         /*output_tiling=*/{2, 3}),
322       TransposeTestCase(/*dims=*/{4}, /*permutation=*/{0},
323                         /*input_tiling=*/{3}),
324       TransposeTestCase(/*dims=*/{5}, /*permutation=*/{0},
325                         /*input_tiling=*/{},
326                         /*output_tiling=*/{3}),
327       TransposeTestCase(/*dims=*/{8}, /*permutation=*/{0},
328                         /*input_tiling=*/{},
329                         /*output_tiling=*/{3}),
330       TransposeTestCase(/*dims=*/{8}, /*permutation=*/{0},
331                         /*input_tiling=*/{3},
332                         /*output_tiling=*/{}),
333       TransposeTestCase(/*dims=*/{29}, /*permutation=*/{0},
334                         /*input_tiling=*/{},
335                         /*output_tiling=*/{3}),
336       TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
337                         /*input_tiling=*/{4}),
338       TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
339                         /*input_tiling=*/{}, /*output_tiling=*/{5}),
340       TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
341                         /*input_tiling=*/{2, 4}),
342       TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
343                         /*input_tiling=*/{}, /*output_tiling=*/{5, 2}),
344       TransposeTestCase(/*dims=*/{128, 224, 224, 3},
345                         /*permutation=*/{3, 1, 2, 0},
346                         /*input_tiling=*/{},
347                         /*output_tiling=*/{8, 128}),
348   };
349   return cases;
350 }
351 
352 class TransposeTest : public ::testing::TestWithParam<TransposeTestCase> {
353  protected:
354   template <typename T>
TestTranspose(int parallelism)355   void TestTranspose(int parallelism) {
356     const TransposeTestCase test = GetParam();
357     tensorflow::thread::ThreadPool threadpool(tensorflow::Env::Default(),
358                                               "Transpose", parallelism);
359     std::vector<int64_t> output_dims = Permute(test.dims, test.permutation);
360     TF_ASSERT_OK_AND_ASSIGN(
361         auto plan, TransposePlan::Create(
362                        sizeof(T), test.dims, test.permutation,
363                        TransposePlan::Tiling{test.input_tiling},
364                        TransposePlan::Tiling{test.output_tiling},
365                        TransposePlan::Transformation::kNone, parallelism));
366     VLOG(1) << plan->ToString();
367     xla::Array<T> untiled_input(test.dims);
368     untiled_input.FillIota(0);
369     xla::Array<T> expected_untiled_output(output_dims);
370     TransposeUsingEigen(untiled_input.data(), expected_untiled_output.data(),
371                         test.dims, output_dims, test.permutation);
372 
373     auto tiled_input = TileArray(untiled_input, test.input_tiling);
374     auto expected_tiled_output =
375         TileArray(expected_untiled_output, test.output_tiling);
376 
377     std::vector<T> output(
378         SizeOfTiledArray(plan->OutputDims(), test.output_tiling), -1);
379     plan->Execute(
380         tiled_input.data(), output.data(),
381         [&](std::function<void()> fn) { threadpool.Schedule(std::move(fn)); });
382 
383     EXPECT_EQ(expected_tiled_output, output);
384   }
385 };
386 
// Single-threaded runs over 1- to 16-byte element types.
TEST_P(TransposeTest, TransposeInt8) {
  TestTranspose<int8_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt16) {
  TestTranspose<int16_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt32) {
  TestTranspose<int32_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt64) {
  TestTranspose<int64_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt128) {
  TestTranspose<absl::int128>(/*parallelism=*/1);
}

// The same checks with a 16-thread pool and plan parallelism of 16.
TEST_P(TransposeTest, ParallelTransposeInt8) {
  TestTranspose<int8_t>(/*parallelism=*/16);
}
TEST_P(TransposeTest, ParallelTransposeInt32) {
  TestTranspose<int32_t>(/*parallelism=*/16);
}

INSTANTIATE_TEST_SUITE_P(TransposeTestInstance, TransposeTest,
                         ::testing::ValuesIn(GetTransposeTestCases()));
398 
// A negative byte stride makes the plan read the input backwards, so an
// identity permutation reverses the data. Execute() is handed a pointer to the
// last input element.
TEST(TransposeTest, NegativeStrides1D) {
  int64_t n = 10;
  std::vector<int32_t> input(n);
  std::vector<int32_t> expected(n);
  for (int64_t i = 0; i < n; ++i) {
    input[i] = static_cast<int32_t>(7 + i);
    expected[n - 1 - i] = static_cast<int32_t>(7 + i);
  }
  std::vector<int32_t> output(n);
  TF_ASSERT_OK_AND_ASSIGN(
      auto plan, TransposePlan::Create(
                     sizeof(int32_t), {n}, /*permutation=*/{0},
                     TransposePlan::Striding{{-int64_t{sizeof(int32_t)}}}));
  plan->Execute(input.data() + (n - 1), output.data());
  EXPECT_EQ(expected, output);
}
413 
// Rows use a positive byte stride and columns a negative one, with Execute()
// pointed at element (0, 3). Logical element (i, j) is therefore input(i, 3-j),
// and transposing that gives `expected`.
TEST(TransposeTest, NegativeStrides2D) {
  xla::Array<int16_t> input = {
      {1, 2, 3, 4},
      {5, 6, 7, 8},
      {9, 10, 11, 12},
  };
  xla::Array<int16_t> expected = {
      {4, 8, 12},
      {3, 7, 11},
      {2, 6, 10},
      {1, 5, 9},
  };
  xla::Array<int16_t> output({4, 3});
  const int64_t row_stride_bytes = 4 * static_cast<int64_t>(sizeof(int16_t));
  const int64_t col_stride_bytes = -static_cast<int64_t>(sizeof(int16_t));
  TF_ASSERT_OK_AND_ASSIGN(
      auto plan,
      TransposePlan::Create(
          sizeof(int16_t), {3, 4}, /*permutation=*/{1, 0},
          TransposePlan::Striding{{row_stride_bytes, col_stride_bytes}}));
  plan->Execute(input.data() + 3, output.data());
  EXPECT_EQ(expected, output);
}
435 
BenchmarkCases()436 static std::vector<TransposeTestCase> BenchmarkCases() {
437   return std::vector<TransposeTestCase>{
438       TransposeTestCase(/*dims=*/{256, 256},
439                         /*permutation=*/{1, 0}),
440       TransposeTestCase(/*dims=*/{512, 512},
441                         /*permutation=*/{1, 0}),
442       TransposeTestCase(/*dims=*/{1024, 1024},
443                         /*permutation=*/{1, 0}),
444       TransposeTestCase(/*dims=*/{256, 256, 256},
445                         /*permutation=*/{0, 2, 1}),
446       TransposeTestCase(/*dims=*/{256, 256, 256},
447                         /*permutation=*/{1, 0, 2}),
448       TransposeTestCase(/*dims=*/{256, 256, 256},
449                         /*permutation=*/{1, 2, 0}),
450       TransposeTestCase(/*dims=*/{256, 256, 256},
451                         /*permutation=*/{2, 0, 1}),
452       TransposeTestCase(/*dims=*/{256, 256, 256},
453                         /*permutation=*/{2, 1, 0}),
454       TransposeTestCase(/*dims=*/{512, 512, 512},
455                         /*permutation=*/{0, 2, 1}),
456       TransposeTestCase(/*dims=*/{512, 512, 512},
457                         /*permutation=*/{1, 0, 2}),
458       TransposeTestCase(/*dims=*/{512, 512, 512},
459                         /*permutation=*/{1, 2, 0}),
460       TransposeTestCase(/*dims=*/{512, 512, 512},
461                         /*permutation=*/{2, 0, 1}),
462       TransposeTestCase(/*dims=*/{512, 512, 512},
463                         /*permutation=*/{2, 1, 0}),
464       TransposeTestCase(/*dims=*/{64, 224, 224, 3},
465                         /*permutation=*/{1, 2, 3, 0}),
466       TransposeTestCase(/*dims=*/{256, 64, 64, 3},
467                         /*permutation=*/{1, 3, 2, 0}),
468   };
469 }
470 
471 template <typename T>
BM_Eigen(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)472 void BM_Eigen(const TransposeTestCase& bm, int parallelism,
473               ::testing::benchmark::State& state) {
474   CHECK_EQ(parallelism, 1);
475   Array<T> input(bm.dims);
476   input.FillIota(0);
477   std::vector<int64_t> output_dims = Permute(bm.dims, bm.permutation);
478   Array<T> output(output_dims);
479   for (auto s : state) {
480     TransposeUsingEigen(input.data(), output.data(), bm.dims, output_dims,
481                         bm.permutation);
482     tensorflow::testing::DoNotOptimize(output);
483   }
484 }
BM_Eigen_uint8(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)485 static void BM_Eigen_uint8(const TransposeTestCase& bm, int parallelism,
486                            ::testing::benchmark::State& state) {
487   BM_Eigen<uint8_t>(std::move(bm), parallelism, state);
488 }
BM_Eigen_float(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)489 static void BM_Eigen_float(const TransposeTestCase& bm, int parallelism,
490                            ::testing::benchmark::State& state) {
491   BM_Eigen<float>(bm, parallelism, state);
492 }
493 
494 template <typename T>
BM_Transpose(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)495 void BM_Transpose(const TransposeTestCase& bm, int parallelism,
496                   ::testing::benchmark::State& state) {
497   TF_ASSERT_OK_AND_ASSIGN(
498       auto plan,
499       TransposePlan::Create(sizeof(T), bm.dims, bm.permutation,
500                             TransposePlan::Tiling{}, TransposePlan::Tiling{},
501                             TransposePlan::Transformation::kNone, parallelism));
502   Array<T> input(bm.dims);
503   input.FillIota(0);
504   std::vector<int64_t> output_dims = Permute(bm.dims, bm.permutation);
505   Array<T> output(output_dims);
506   tensorflow::thread::ThreadPool threadpool(tensorflow::Env::Default(),
507                                             "Transpose", parallelism);
508   for (auto s : state) {
509     plan->Execute(input.data(), output.data(), [&](std::function<void()> fn) {
510       threadpool.Schedule(std::move(fn));
511     });
512     tensorflow::testing::DoNotOptimize(output);
513   }
514 }
BM_Transpose_uint8(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)515 static void BM_Transpose_uint8(const TransposeTestCase& bm, int parallelism,
516                                ::testing::benchmark::State& state) {
517   BM_Transpose<uint8_t>(bm, parallelism, state);
518 }
BM_Transpose_float(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)519 static void BM_Transpose_float(const TransposeTestCase& bm, int parallelism,
520                                ::testing::benchmark::State& state) {
521   BM_Transpose<float>(bm, parallelism, state);
522 }
523 
__anonb8a511080302() 524 static void* benchmarks = []() {
525   using BenchmarkFn =
526       void (*)(const TransposeTestCase&, int, testing::benchmark::State&);
527   std::vector<std::tuple<std::string, BenchmarkFn, std::vector<int>>> variants =
528       {
529           {"BM_Eigen_uint8", BM_Eigen_uint8, {1}},
530           {"BM_Transpose_uint8", BM_Transpose_uint8, {1, 4, 8}},  //
531           {"BM_Eigen_float", BM_Eigen_float, {1}},
532           {"BM_Transpose_float", BM_Transpose_float, {1, 4, 8}},  //
533   };
534   auto benchmark_cases = BenchmarkCases();
535   for (const auto& benchmark_case : benchmark_cases) {
536     for (const auto& variant : variants) {
537       for (int num_threads : std::get<2>(variant)) {
538         std::string name =
539             absl::StrCat(std::get<0>(variant), "_threads_", num_threads, "_",
540                          absl::StrJoin(benchmark_case.dims, "_"), "_perm_",
541                          absl::StrJoin(benchmark_case.permutation, "_"));
542 
543         TransposeTestCase testcase = benchmark_case;
544         BenchmarkFn fn = std::get<1>(variant);
545         benchmark::RegisterBenchmark(
546             name.c_str(), [fn, num_threads, testcase](benchmark::State& state) {
547               fn(testcase, num_threads, state);
548             });
549       }
550     }
551   }
552   return nullptr;
553 }();
554 
// A TransposePlanCache with capacity 2 should return the same plan for
// identical requests, build distinct plans for distinct permutations, and drop
// p1's entry once two newer plans have been inserted.
TEST(TransposePlanCache, Basics) {
  TransposePlanCache cache(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto p1, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                                 /*permutation=*/{2, 1, 0}));
  TF_ASSERT_OK_AND_ASSIGN(
      auto p1a, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                                  /*permutation=*/{2, 1, 0}));
  // An identical request is served from the cache.
  EXPECT_EQ(p1.get(), p1a.get());
  TF_ASSERT_OK_AND_ASSIGN(
      auto p2, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                                 /*permutation=*/{1, 2, 0}));
  EXPECT_NE(p1.get(), p2.get());
  TF_ASSERT_OK_AND_ASSIGN(
      auto p3, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                                 /*permutation=*/{0, 1, 2}));
  EXPECT_NE(p3.get(), p1.get());
  // With capacity 2, inserting p2 and p3 pushes p1 out of the cache, so the
  // same request now builds a fresh plan. p1 is still alive here, so the new
  // plan cannot share its address.
  TF_ASSERT_OK_AND_ASSIGN(
      auto p1b, cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                                  /*permutation=*/{2, 1, 0}));
  EXPECT_NE(p1.get(), p1b.get());
}
577 
578 }  // namespace xla
579