1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/transpose.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/inlined_vector.h"
25 #include "absl/numeric/int128.h"
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/compiler/xla/array.h"
28 #include "tensorflow/compiler/xla/permutation_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33 #include "tensorflow/core/platform/threadpool.h"
34 #include "tensorflow/core/protobuf/error_codes.pb.h"
35
36 namespace xla {
37
38 class TestTransposePlan : public TransposePlan {
39 public:
40 using TransposePlan::CoalesceDimensions;
41 using TransposePlan::RemoveTrivialDimensions;
42 };
43
// Exercises TransposePlan::RemoveTrivialDimensions (reached through
// TestTransposePlan) on one shape that contains extent-1 dimensions and on one
// that does not. Only `dims` and `perm` are inspected after each call.
TEST(TransposeTest, RemoveTrivialDimensions) {
  using DimVector = absl::InlinedVector<int64_t, 4>;

  {
    // Positions 2 and 4 have extent 1: five dimensions are expected to remain,
    // with the permutation renumbered to match.
    DimVector dims{4, 5, 1, 3, 1, 2, 5};
    DimVector perm{0, 2, 1, 4, 3, 6, 5};
    DimVector lda{2, 5, 7, 100, 3, 0, 1};
    DimVector lda_tile{1, 1, 1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1, 1, 1};
    TestTransposePlan::RemoveTrivialDimensions(dims, perm, lda, lda_tile,
                                               input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(4, 5, 3, 2, 5));
    EXPECT_THAT(perm, testing::ElementsAre(0, 1, 2, 4, 3));
  }

  {
    // No extent-1 dimensions: `dims` and `perm` are expected to come back
    // unchanged.
    DimVector dims{4, 5, 3, 2, 5};
    DimVector perm{4, 3, 2, 1, 0};
    DimVector lda{2, 5, 100, 0, 1};
    DimVector lda_tile{1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1};
    TestTransposePlan::RemoveTrivialDimensions(dims, perm, lda, lda_tile,
                                               input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(4, 5, 3, 2, 5));
    EXPECT_THAT(perm, testing::ElementsAre(4, 3, 2, 1, 0));
  }
}
67
// Exercises TransposePlan::CoalesceDimensions (reached through
// TestTransposePlan) on four inputs and checks the resulting dimensions,
// permutation and strides (`lda`).
TEST(TransposeTest, CoalesceDimensions) {
  using DimVector = absl::InlinedVector<int64_t, 4>;

  {
    // Expected to come back unchanged: nothing is merged for this input.
    DimVector dims{4, 5, 1, 3, 1, 2, 5};
    DimVector perm{0, 2, 1, 4, 3, 6, 5};
    DimVector lda{50, 30, 30, 10, 10, 5, 1};
    DimVector lda_tile{1, 1, 1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1, 1, 1};
    TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile,
                                          input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(4, 5, 1, 3, 1, 2, 5));
    EXPECT_THAT(perm, testing::ElementsAre(0, 2, 1, 4, 3, 6, 5));
    EXPECT_THAT(lda, testing::ElementsAre(50, 30, 30, 10, 10, 5, 1));
  }

  {
    // The three middle dimensions keep their relative order in `perm` and are
    // expected to fuse into a single dimension of extent 5 * 3 * 2 = 30.
    DimVector dims{4, 5, 3, 2, 5};
    DimVector perm{4, 1, 2, 3, 0};
    DimVector lda{150, 30, 10, 5, 1};
    DimVector lda_tile{1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1};
    TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile,
                                          input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(4, 30, 5));
    EXPECT_THAT(perm, testing::ElementsAre(2, 1, 0));
    EXPECT_THAT(lda, testing::ElementsAre(150, 5, 1));
  }

  {
    // Identity permutation: everything is expected to collapse into one
    // dimension holding all 600 elements.
    DimVector dims{4, 5, 3, 2, 5};
    DimVector perm{0, 1, 2, 3, 4};
    DimVector lda{150, 30, 10, 5, 1};
    DimVector lda_tile{1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1};
    TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile,
                                          input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(600));
    EXPECT_THAT(perm, testing::ElementsAre(0));
    EXPECT_THAT(lda, testing::ElementsAre(1));
  }

  {
    // Same shape and permutation as the second case, but the stride of 7 is
    // non-standard and prevents coalescing across it: only the 5 and 3 merge.
    DimVector dims{4, 5, 3, 2, 5};
    DimVector perm{4, 1, 2, 3, 0};
    DimVector lda{150, 30, 10, 7, 1};
    DimVector lda_tile{1, 1, 1, 1, 1};
    DimVector input_tiling{1, 1, 1, 1, 1};
    DimVector output_tiling{1, 1, 1, 1, 1};
    TestTransposePlan::CoalesceDimensions(dims, perm, lda, lda_tile,
                                          input_tiling, output_tiling);
    EXPECT_THAT(dims, testing::ElementsAre(4, 15, 2, 5));
    EXPECT_THAT(perm, testing::ElementsAre(3, 1, 2, 0));
    EXPECT_THAT(lda, testing::ElementsAre(150, 10, 7, 1));
  }
}
117
// Requesting a non-trivial tiling on both the input and the output is expected
// to be rejected with UNIMPLEMENTED and a descriptive message.
TEST(TransposeTest, InvalidTilings) {
  // The Tiling arguments are built inline on purpose: they are constructed
  // from braced lists and are only needed for the duration of this call.
  auto plan =
      TransposePlan::Create(sizeof(float), {3, 4, 5}, {0, 1, 2},
                            /*input_layout=*/TransposePlan::Tiling{{8, 128}},
                            /*output_tiling=*/TransposePlan::Tiling{{4}});
  const auto& status = plan.status();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
  EXPECT_THAT(
      status.error_message(),
      testing::HasSubstr(
          "Only one of the input and output may have a non-trivial tiling"));
}
129
130 // Computes the size in elements of a tiled array.
SizeOfTiledArray(absl::Span<int64_t const> shape,absl::Span<int64_t const> tiling)131 int64_t SizeOfTiledArray(absl::Span<int64_t const> shape,
132 absl::Span<int64_t const> tiling) {
133 int64_t size = 1;
134 for (size_t i = 0; i < shape.size(); ++i) {
135 if (i >= shape.size() - tiling.size()) {
136 size *= RoundUpTo(shape[i], tiling[i - (shape.size() - tiling.size())]);
137 } else {
138 size *= shape[i];
139 }
140 }
141 return size;
142 }
143
144 // Advances 'indices' in the lexicographical order of the multidimensional
145 // array with `shape`. Returns false if the end of the array has been reached.
BumpIndices(absl::Span<int64_t const> shape,absl::Span<int64_t> indices)146 bool BumpIndices(absl::Span<int64_t const> shape, absl::Span<int64_t> indices) {
147 CHECK_EQ(shape.size(), indices.size());
148 for (int dimno = indices.size() - 1; dimno >= 0; --dimno) {
149 if (indices[dimno] + 1 < shape[dimno]) {
150 indices[dimno]++;
151 // Whenever an index of a dimension is increased, it means that all
152 // following dimensions have maxed out, so they must go to 0.
153 std::fill(indices.begin() + dimno + 1, indices.end(), 0);
154 return true;
155 }
156 }
157 return false;
158 }
159
160 // Converts a multidimensional index `indices` into an array with `shape` and
161 // tiling `tiling` into a linear offset into a buffer.
IndexToLinearIndex(absl::Span<int64_t const> shape,absl::Span<int64_t const> tiling,absl::Span<int64_t const> indices)162 int64_t IndexToLinearIndex(absl::Span<int64_t const> shape,
163 absl::Span<int64_t const> tiling,
164 absl::Span<int64_t const> indices) {
165 CHECK_LE(tiling.size(), shape.size());
166 CHECK_EQ(shape.size(), indices.size());
167 int64_t stride = 1;
168 int64_t offset = 0;
169
170 auto index_it = indices.rbegin();
171 auto tile_it = tiling.rbegin();
172 for (; tile_it != tiling.rend(); ++index_it, ++tile_it) {
173 offset += (*index_it % *tile_it) * stride;
174 stride *= *tile_it;
175 }
176 index_it = indices.rbegin();
177 tile_it = tiling.rbegin();
178 auto shape_it = shape.rbegin();
179 for (; tile_it != tiling.rend(); ++index_it, ++shape_it, ++tile_it) {
180 offset += (*index_it / *tile_it) * stride;
181 stride *= CeilOfRatio(*shape_it, *tile_it);
182 }
183 for (; shape_it != shape.rend(); ++index_it, ++shape_it) {
184 offset += *index_it * stride;
185 stride *= *shape_it;
186 }
187 return offset;
188 }
189
190 // Slow reference code that converts an array from an untiled layout into a
191 // tiled layout.
192 template <typename T>
TileArray(const Array<T> & in,absl::Span<int64_t const> tiling)193 std::vector<T> TileArray(const Array<T>& in, absl::Span<int64_t const> tiling) {
194 std::vector<T> out(SizeOfTiledArray(in.dimensions(), tiling), -1);
195 if (in.num_elements() == 0) {
196 return out;
197 }
198 std::vector<int64_t> indices(in.num_dimensions(), 0);
199 do {
200 int64_t i = IndexToLinearIndex(in.dimensions(), tiling, indices);
201 out.at(i) = in(indices);
202 } while (BumpIndices(in.dimensions(), absl::MakeSpan(indices)));
203 return out;
204 }
205
206 // Reference implementation: transpose using Eigen.
207 template <typename T, int NDIMS>
TransposeUsingEigenNd(const T * input,T * output,absl::Span<int64_t const> dims,absl::Span<int64_t const> dims_out,absl::Span<int64_t const> permutation)208 void TransposeUsingEigenNd(const T* input, T* output,
209 absl::Span<int64_t const> dims,
210 absl::Span<int64_t const> dims_out,
211 absl::Span<int64_t const> permutation) {
212 typedef Eigen::TensorMap<
213 Eigen::Tensor<T, NDIMS, Eigen::RowMajor, Eigen::DenseIndex>,
214 Eigen::Aligned>
215 Tensor;
216 typedef Eigen::TensorMap<
217 Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, Eigen::DenseIndex>,
218 Eigen::Aligned>
219 ConstTensor;
220
221 Eigen::array<int, NDIMS> p;
222 Eigen::DSizes<Eigen::DenseIndex, NDIMS> dims_eigen;
223 Eigen::DSizes<Eigen::DenseIndex, NDIMS> dims_out_eigen;
224 for (int i = 0; i < NDIMS; ++i) {
225 p[i] = permutation[i];
226 dims_eigen[i] = dims[i];
227 dims_out_eigen[i] = dims_out[i];
228 }
229 auto x = ConstTensor(input, dims_eigen);
230 auto y = Tensor(output, dims_out_eigen);
231 y = x.shuffle(p);
232 }
233
234 template <typename T>
TransposeUsingEigen(const T * input,T * output,absl::Span<int64_t const> dims,absl::Span<int64_t const> dims_out,absl::Span<int64_t const> permutation)235 void TransposeUsingEigen(const T* input, T* output,
236 absl::Span<int64_t const> dims,
237 absl::Span<int64_t const> dims_out,
238 absl::Span<int64_t const> permutation) {
239 switch (dims.size()) {
240 case 0:
241 return;
242 case 1:
243 TransposeUsingEigenNd<T, 1>(input, output, dims, dims_out, permutation);
244 return;
245 case 2:
246 TransposeUsingEigenNd<T, 2>(input, output, dims, dims_out, permutation);
247 return;
248 case 3:
249 TransposeUsingEigenNd<T, 3>(input, output, dims, dims_out, permutation);
250 return;
251 case 4:
252 TransposeUsingEigenNd<T, 4>(input, output, dims, dims_out, permutation);
253 return;
254 default:
255 LOG(FATAL) << "Unimplemented Eigen transpose rank";
256 }
257 }
258
259 struct TransposeTestCase {
TransposeTestCasexla::TransposeTestCase260 TransposeTestCase(std::vector<int64_t> dims, std::vector<int64_t> permutation,
261 std::vector<int64_t> input_tiling = {},
262 std::vector<int64_t> output_tiling = {})
263 : dims(std::move(dims)),
264 permutation(std::move(permutation)),
265 input_tiling(std::move(input_tiling)),
266 output_tiling(std::move(output_tiling)) {}
267
268 std::vector<int64_t> dims;
269 std::vector<int64_t> permutation;
270 std::vector<int64_t> input_tiling;
271 std::vector<int64_t> output_tiling;
272
ToStringxla::TransposeTestCase273 std::string ToString() const {
274 return absl::StrFormat(
275 "[%s],perm=[%s],tiling=[%s]/[%s]", absl::StrJoin(dims, ","),
276 absl::StrJoin(permutation, ","), absl::StrJoin(input_tiling, ","),
277 absl::StrJoin(output_tiling, ","));
278 }
279 };
280
operator <<(std::ostream & os,const TransposeTestCase & test)281 std::ostream& operator<<(std::ostream& os, const TransposeTestCase& test) {
282 os << test.ToString();
283 return os;
284 }
285
GetTransposeTestCases()286 std::vector<TransposeTestCase> GetTransposeTestCases() {
287 std::vector<TransposeTestCase> cases = {
288 TransposeTestCase(/*dims=*/{1}, /*permutation=*/{0}),
289 TransposeTestCase(/*dims=*/{4}, /*permutation=*/{0}),
290 TransposeTestCase(/*dims=*/{27}, /*permutation=*/{0}),
291 TransposeTestCase(/*dims=*/{1, 1}, /*permutation=*/{0, 1}),
292 TransposeTestCase(/*dims=*/{1, 1}, /*permutation=*/{1, 0}),
293 TransposeTestCase(/*dims=*/{2, 2}, /*permutation=*/{0, 1}),
294 TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{1, 0}),
295 TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{0, 1}),
296 TransposeTestCase(/*dims=*/{4, 4}, /*permutation=*/{1, 0}),
297 TransposeTestCase(/*dims=*/{8, 8}, /*permutation=*/{0, 1}),
298 TransposeTestCase(/*dims=*/{8, 8}, /*permutation=*/{1, 0}),
299 TransposeTestCase(/*dims=*/{16, 16}, /*permutation=*/{0, 1}),
300 TransposeTestCase(/*dims=*/{16, 16}, /*permutation=*/{1, 0}),
301 TransposeTestCase(/*dims=*/{11, 15}, /*permutation=*/{0, 1}),
302 TransposeTestCase(/*dims=*/{11, 15}, /*permutation=*/{1, 0}),
303 TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{0, 1, 2}),
304 TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{0, 2, 1}),
305 TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{1, 2, 0}),
306 TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{1, 0, 2}),
307 TransposeTestCase(/*dims=*/{11, 15, 13}, /*permutation=*/{2, 0, 1}),
308 TransposeTestCase(/*dims=*/{64, 64, 64}, /*permutation=*/{2, 1, 0}),
309 TransposeTestCase(/*dims=*/{256, 256, 256}, /*permutation=*/{2, 1, 0}),
310 TransposeTestCase(/*dims=*/{4, 8, 16, 32}, /*permutation=*/{3, 1, 0, 2}),
311 TransposeTestCase(/*dims=*/{64, 224, 224, 3},
312 /*permutation=*/{3, 1, 2, 0}),
313
314 TransposeTestCase(/*dims=*/{3}, /*permutation=*/{0},
315 /*input_tiling=*/{3}),
316 TransposeTestCase(/*dims=*/{3}, /*permutation=*/{0},
317 /*input_tiling=*/{},
318 /*output_tiling=*/{3}),
319 TransposeTestCase(/*dims=*/{2, 4, 6}, /*permutation=*/{0, 1, 2},
320 /*input_tiling=*/{},
321 /*output_tiling=*/{2, 3}),
322 TransposeTestCase(/*dims=*/{4}, /*permutation=*/{0},
323 /*input_tiling=*/{3}),
324 TransposeTestCase(/*dims=*/{5}, /*permutation=*/{0},
325 /*input_tiling=*/{},
326 /*output_tiling=*/{3}),
327 TransposeTestCase(/*dims=*/{8}, /*permutation=*/{0},
328 /*input_tiling=*/{},
329 /*output_tiling=*/{3}),
330 TransposeTestCase(/*dims=*/{8}, /*permutation=*/{0},
331 /*input_tiling=*/{3},
332 /*output_tiling=*/{}),
333 TransposeTestCase(/*dims=*/{29}, /*permutation=*/{0},
334 /*input_tiling=*/{},
335 /*output_tiling=*/{3}),
336 TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
337 /*input_tiling=*/{4}),
338 TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
339 /*input_tiling=*/{}, /*output_tiling=*/{5}),
340 TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
341 /*input_tiling=*/{2, 4}),
342 TransposeTestCase(/*dims=*/{12, 7}, /*permutation=*/{1, 0},
343 /*input_tiling=*/{}, /*output_tiling=*/{5, 2}),
344 TransposeTestCase(/*dims=*/{128, 224, 224, 3},
345 /*permutation=*/{3, 1, 2, 0},
346 /*input_tiling=*/{},
347 /*output_tiling=*/{8, 128}),
348 };
349 return cases;
350 }
351
352 class TransposeTest : public ::testing::TestWithParam<TransposeTestCase> {
353 protected:
354 template <typename T>
TestTranspose(int parallelism)355 void TestTranspose(int parallelism) {
356 const TransposeTestCase test = GetParam();
357 tensorflow::thread::ThreadPool threadpool(tensorflow::Env::Default(),
358 "Transpose", parallelism);
359 std::vector<int64_t> output_dims = Permute(test.dims, test.permutation);
360 TF_ASSERT_OK_AND_ASSIGN(
361 auto plan, TransposePlan::Create(
362 sizeof(T), test.dims, test.permutation,
363 TransposePlan::Tiling{test.input_tiling},
364 TransposePlan::Tiling{test.output_tiling},
365 TransposePlan::Transformation::kNone, parallelism));
366 VLOG(1) << plan->ToString();
367 xla::Array<T> untiled_input(test.dims);
368 untiled_input.FillIota(0);
369 xla::Array<T> expected_untiled_output(output_dims);
370 TransposeUsingEigen(untiled_input.data(), expected_untiled_output.data(),
371 test.dims, output_dims, test.permutation);
372
373 auto tiled_input = TileArray(untiled_input, test.input_tiling);
374 auto expected_tiled_output =
375 TileArray(expected_untiled_output, test.output_tiling);
376
377 std::vector<T> output(
378 SizeOfTiledArray(plan->OutputDims(), test.output_tiling), -1);
379 plan->Execute(
380 tiled_input.data(), output.data(),
381 [&](std::function<void()> fn) { threadpool.Schedule(std::move(fn)); });
382
383 EXPECT_EQ(expected_tiled_output, output);
384 }
385 };
386
// Single-threaded runs over element types of 1, 2, 4, 8 and 16 bytes.
TEST_P(TransposeTest, TransposeInt8) {
  TestTranspose<int8_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt16) {
  TestTranspose<int16_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt32) {
  TestTranspose<int32_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt64) {
  TestTranspose<int64_t>(/*parallelism=*/1);
}
TEST_P(TransposeTest, TransposeInt128) {
  TestTranspose<absl::int128>(/*parallelism=*/1);
}

// The same check with a parallelism of 16.
TEST_P(TransposeTest, ParallelTransposeInt8) {
  TestTranspose<int8_t>(/*parallelism=*/16);
}
TEST_P(TransposeTest, ParallelTransposeInt32) {
  TestTranspose<int32_t>(/*parallelism=*/16);
}

// Runs each TEST_P above once per entry returned by GetTransposeTestCases().
INSTANTIATE_TEST_SUITE_P(TransposeTestInstance, TransposeTest,
                         ::testing::ValuesIn(GetTransposeTestCases()));
398
// A rank-1 "transpose" whose input is described by a stride of minus one
// element: starting at the last input element and walking backwards is
// expected to produce the reversed sequence.
TEST(TransposeTest, NegativeStrides1D) {
  const int64_t count = 10;
  std::vector<int32_t> input(count);
  absl::c_iota(input, int32_t{7});
  // The expected output is the input read back to front.
  const std::vector<int32_t> expected(input.rbegin(), input.rend());
  std::vector<int32_t> output(count);

  TF_ASSERT_OK_AND_ASSIGN(
      auto plan, TransposePlan::Create(
                     sizeof(int32_t), {count}, /*permutation=*/{0},
                     TransposePlan::Striding{{-int64_t{sizeof(int32_t)}}}));
  // With a negative stride the base pointer is the last element of `input`.
  plan->Execute(input.data() + (count - 1), output.data());
  EXPECT_EQ(expected, output);
}
413
// Transposes a 3x4 array whose minor dimension is described by a stride of
// minus one element. The input pointer starts at the last column of the first
// row, so columns are consumed right to left and the 4x3 result lists them in
// reverse order.
TEST(TransposeTest, NegativeStrides2D) {
  xla::Array<int16_t> input = {
      {1, 2, 3, 4},
      {5, 6, 7, 8},
      {9, 10, 11, 12},
  };
  xla::Array<int16_t> expected = {
      {4, 8, 12},
      {3, 7, 11},
      {2, 6, 10},
      {1, 5, 9},
  };
  xla::Array<int16_t> output({4, 3});

  // Strides in bytes: one full row of four elements forward per step along
  // dimension 0, one element backward per step along dimension 1.
  const int64_t element_bytes = sizeof(int16_t);
  const int64_t row_bytes = 4 * sizeof(int16_t);
  TF_ASSERT_OK_AND_ASSIGN(
      auto plan,
      TransposePlan::Create(
          sizeof(int16_t), {3, 4}, /*permutation=*/{1, 0},
          TransposePlan::Striding{{row_bytes, -element_bytes}}));
  plan->Execute(input.data() + 3, output.data());
  EXPECT_EQ(expected, output);
}
435
BenchmarkCases()436 static std::vector<TransposeTestCase> BenchmarkCases() {
437 return std::vector<TransposeTestCase>{
438 TransposeTestCase(/*dims=*/{256, 256},
439 /*permutation=*/{1, 0}),
440 TransposeTestCase(/*dims=*/{512, 512},
441 /*permutation=*/{1, 0}),
442 TransposeTestCase(/*dims=*/{1024, 1024},
443 /*permutation=*/{1, 0}),
444 TransposeTestCase(/*dims=*/{256, 256, 256},
445 /*permutation=*/{0, 2, 1}),
446 TransposeTestCase(/*dims=*/{256, 256, 256},
447 /*permutation=*/{1, 0, 2}),
448 TransposeTestCase(/*dims=*/{256, 256, 256},
449 /*permutation=*/{1, 2, 0}),
450 TransposeTestCase(/*dims=*/{256, 256, 256},
451 /*permutation=*/{2, 0, 1}),
452 TransposeTestCase(/*dims=*/{256, 256, 256},
453 /*permutation=*/{2, 1, 0}),
454 TransposeTestCase(/*dims=*/{512, 512, 512},
455 /*permutation=*/{0, 2, 1}),
456 TransposeTestCase(/*dims=*/{512, 512, 512},
457 /*permutation=*/{1, 0, 2}),
458 TransposeTestCase(/*dims=*/{512, 512, 512},
459 /*permutation=*/{1, 2, 0}),
460 TransposeTestCase(/*dims=*/{512, 512, 512},
461 /*permutation=*/{2, 0, 1}),
462 TransposeTestCase(/*dims=*/{512, 512, 512},
463 /*permutation=*/{2, 1, 0}),
464 TransposeTestCase(/*dims=*/{64, 224, 224, 3},
465 /*permutation=*/{1, 2, 3, 0}),
466 TransposeTestCase(/*dims=*/{256, 64, 64, 3},
467 /*permutation=*/{1, 3, 2, 0}),
468 };
469 }
470
471 template <typename T>
BM_Eigen(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)472 void BM_Eigen(const TransposeTestCase& bm, int parallelism,
473 ::testing::benchmark::State& state) {
474 CHECK_EQ(parallelism, 1);
475 Array<T> input(bm.dims);
476 input.FillIota(0);
477 std::vector<int64_t> output_dims = Permute(bm.dims, bm.permutation);
478 Array<T> output(output_dims);
479 for (auto s : state) {
480 TransposeUsingEigen(input.data(), output.data(), bm.dims, output_dims,
481 bm.permutation);
482 tensorflow::testing::DoNotOptimize(output);
483 }
484 }
BM_Eigen_uint8(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)485 static void BM_Eigen_uint8(const TransposeTestCase& bm, int parallelism,
486 ::testing::benchmark::State& state) {
487 BM_Eigen<uint8_t>(std::move(bm), parallelism, state);
488 }
BM_Eigen_float(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)489 static void BM_Eigen_float(const TransposeTestCase& bm, int parallelism,
490 ::testing::benchmark::State& state) {
491 BM_Eigen<float>(bm, parallelism, state);
492 }
493
494 template <typename T>
BM_Transpose(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)495 void BM_Transpose(const TransposeTestCase& bm, int parallelism,
496 ::testing::benchmark::State& state) {
497 TF_ASSERT_OK_AND_ASSIGN(
498 auto plan,
499 TransposePlan::Create(sizeof(T), bm.dims, bm.permutation,
500 TransposePlan::Tiling{}, TransposePlan::Tiling{},
501 TransposePlan::Transformation::kNone, parallelism));
502 Array<T> input(bm.dims);
503 input.FillIota(0);
504 std::vector<int64_t> output_dims = Permute(bm.dims, bm.permutation);
505 Array<T> output(output_dims);
506 tensorflow::thread::ThreadPool threadpool(tensorflow::Env::Default(),
507 "Transpose", parallelism);
508 for (auto s : state) {
509 plan->Execute(input.data(), output.data(), [&](std::function<void()> fn) {
510 threadpool.Schedule(std::move(fn));
511 });
512 tensorflow::testing::DoNotOptimize(output);
513 }
514 }
BM_Transpose_uint8(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)515 static void BM_Transpose_uint8(const TransposeTestCase& bm, int parallelism,
516 ::testing::benchmark::State& state) {
517 BM_Transpose<uint8_t>(bm, parallelism, state);
518 }
BM_Transpose_float(const TransposeTestCase & bm,int parallelism,::testing::benchmark::State & state)519 static void BM_Transpose_float(const TransposeTestCase& bm, int parallelism,
520 ::testing::benchmark::State& state) {
521 BM_Transpose<float>(bm, parallelism, state);
522 }
523
__anonb8a511080302() 524 static void* benchmarks = []() {
525 using BenchmarkFn =
526 void (*)(const TransposeTestCase&, int, testing::benchmark::State&);
527 std::vector<std::tuple<std::string, BenchmarkFn, std::vector<int>>> variants =
528 {
529 {"BM_Eigen_uint8", BM_Eigen_uint8, {1}},
530 {"BM_Transpose_uint8", BM_Transpose_uint8, {1, 4, 8}}, //
531 {"BM_Eigen_float", BM_Eigen_float, {1}},
532 {"BM_Transpose_float", BM_Transpose_float, {1, 4, 8}}, //
533 };
534 auto benchmark_cases = BenchmarkCases();
535 for (const auto& benchmark_case : benchmark_cases) {
536 for (const auto& variant : variants) {
537 for (int num_threads : std::get<2>(variant)) {
538 std::string name =
539 absl::StrCat(std::get<0>(variant), "_threads_", num_threads, "_",
540 absl::StrJoin(benchmark_case.dims, "_"), "_perm_",
541 absl::StrJoin(benchmark_case.permutation, "_"));
542
543 TransposeTestCase testcase = benchmark_case;
544 BenchmarkFn fn = std::get<1>(variant);
545 benchmark::RegisterBenchmark(
546 name.c_str(), [fn, num_threads, testcase](benchmark::State& state) {
547 fn(testcase, num_threads, state);
548 });
549 }
550 }
551 }
552 return nullptr;
553 }();
554
// Checks hit, miss and eviction behaviour of a TransposePlanCache constructed
// with the argument 2 (NOTE(review): presumably its capacity -- confirm).
// All lookups use 4-byte elements and dims {1, 2, 3}; only the permutation
// varies.
TEST(TransposePlanCache, Basics) {
  TransposePlanCache cache(2);
  const auto lookup = [&cache](std::vector<int64_t> permutation) {
    return cache.GetOrCreate(/*elem_size_in_bytes=*/4, /*dims=*/{1, 2, 3},
                             permutation);
  };

  // Asking twice for the same key yields the very same plan object.
  TF_ASSERT_OK_AND_ASSIGN(auto p1, lookup({2, 1, 0}));
  TF_ASSERT_OK_AND_ASSIGN(auto p1a, lookup({2, 1, 0}));
  EXPECT_TRUE(p1.get() == p1a.get());

  // Different permutations yield different plans.
  TF_ASSERT_OK_AND_ASSIGN(auto p2, lookup({1, 2, 0}));
  EXPECT_TRUE(p1.get() != p2.get());
  TF_ASSERT_OK_AND_ASSIGN(auto p3, lookup({0, 1, 2}));
  EXPECT_TRUE(p3.get() != p1.get());

  // After two other keys were requested, the first key is expected to have
  // been dropped from the cache: asking for it again yields a new plan object.
  TF_ASSERT_OK_AND_ASSIGN(auto p1b, lookup({2, 1, 0}));
  EXPECT_TRUE(p1.get() != p1b.get());
}
577
578 } // namespace xla
579