1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bitcast_decomposer.h"
17
18 #include <string>
19 #include <tuple>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/random/random.h"
24 #include "absl/strings/substitute.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/permutation_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/test_utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33
34 namespace xla {
35 namespace {
36
37 namespace m = ::xla::match;
38
AllPermutationsOfShape(const Shape & s)39 std::vector<Shape> AllPermutationsOfShape(const Shape& s) {
40 std::vector<int64_t> dims_perm(s.dimensions_size());
41 absl::c_iota(dims_perm, 0);
42 std::vector<int64_t> layout_perm(s.dimensions_size());
43 absl::c_iota(layout_perm, 0);
44
45 std::vector<Shape> ret;
46 do {
47 do {
48 Shape new_shape = ShapeUtil::MakeShapeWithLayout(
49 s.element_type(), //
50 ComposePermutations(s.dimensions(), dims_perm),
51 ComposePermutations(s.layout().minor_to_major(), layout_perm));
52 ret.push_back(new_shape);
53 } while (absl::c_next_permutation(layout_perm));
54 } while (absl::c_next_permutation(dims_perm));
55 return ret;
56 }
57
AllPermutationsOfShapes(std::vector<std::vector<int64_t>> dims_list)58 std::vector<Shape> AllPermutationsOfShapes(
59 std::vector<std::vector<int64_t>> dims_list) {
60 std::vector<Shape> ret;
61 for (const auto& dims : dims_list) {
62 std::vector<Shape> perms = AllPermutationsOfShape(
63 ShapeUtil::MakeShapeWithDescendingLayout(F32, dims));
64 ret.insert(ret.end(), perms.begin(), perms.end());
65 }
66 return ret;
67 }
68
69 class BitcastDecomposerParameterizedTest
70 : public HloTestBase,
71 public ::testing::WithParamInterface<
72 std::tuple<Shape /*src*/, Shape /*dst*/>> {
73 public:
BitcastDecomposerParameterizedTest()74 BitcastDecomposerParameterizedTest()
75 : HloTestBase(/*verifier_layout_sensitive=*/false,
76 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
77
78 protected:
79 absl::BitGen rand_;
80 };
81
82 INSTANTIATE_TEST_SUITE_P(
83 Handcrafted, BitcastDecomposerParameterizedTest,
84 ::testing::Values(std::make_tuple(
85 ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 4, 2, 2}, {2, 4, 3, 1, 0}),
86 ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 2, 2, 4},
87 {4, 3, 2, 1, 0}))));
88
// The combinatorial suite is only built for optimized (NDEBUG) builds without
// ASan/MSan/TSan; in sanitizer or debug builds it would time out.
#if defined(NDEBUG) &&                                            \
    !(defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
      defined(THREAD_SANITIZER))
INSTANTIATE_TEST_SUITE_P(
    Combinatorial, BitcastDecomposerParameterizedTest,
    ::testing::Combine(
        // Source shapes.
        ::testing::Values(
            ShapeUtil::MakeShapeWithDescendingLayout(F32, {4, 10, 100}),
            ShapeUtil::MakeShapeWithDescendingLayout(F32, {4, 1, 10, 100}),
            ShapeUtil::MakeShapeWithLayout(F32, {4, 10, 100}, {0, 2, 1})),
        // Destination shapes: every dimension/layout permutation of each of
        // the dimension lists below.
        ::testing::ValuesIn(AllPermutationsOfShapes({
            // Same dims as the source, no degenerate dims.
            {4, 10, 100},
            // Same dims as the source, plus a degenerate dim.
            {1, 4, 10, 100},
            // Same rank, but elements moved between dims.
            {2, 20, 100},
            {40, 10, 10},
            {1, 40, 10, 10},
            // Dims merged, no elements moved between dims.
            {40, 100},
            {400, 10},
            {4000},
            {1, 4000},
            // Dims merged and elements moved between dims.
            {20, 200},
            // Dims split, no elements moved between dims.
            {2, 2, 10, 100},
            {4, 2, 5, 100},
            // Dims split and elements moved between dims.
            {2, 5, 5, 80},
        }))));
#endif
125
// Wraps a `src -> dst` bitcast in a loop fusion, runs BitcastDecomposer on it,
// and checks that:
//   * if the pass reports no change, the bitcast was already a
//     reshape-is-bitcast;
//   * otherwise the fused root has the form
//     transpose(bitcast(transpose(param))) with any of the three ops optionally
//     absent, and each op that is present is itself a bitcast according to
//     ShapeUtil;
//   * for a random ~1% of instances, executing the module before and after the
//     pass yields equal results.
TEST_P(BitcastDecomposerParameterizedTest, DoIt) {
  // GetParam() returns a reference to the stored tuple; bind by reference so
  // the two Shapes are not copied for every test instance.
  const auto& [src, dst] = GetParam();

  // $0 is the source shape and $1 the destination shape, both printed with
  // their layouts.
  const char* const kModuleTemplate = R"(
  HloModule module

  fused_comp {
    lhs = $0 parameter(0)
    ROOT root = $1 bitcast(lhs)
  }

  ENTRY main {
    ROOT fusion = $1 fusion($0 parameter(0)), kind=kLoop, calls=fused_comp
  })";
  std::string module_string =
      absl::Substitute(kModuleTemplate, ShapeUtil::HumanStringWithLayout(src),
                       ShapeUtil::HumanStringWithLayout(dst));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));

  // Actually compiling and running the module is expensive; we can't afford to
  // do it for all 9000 (as of writing) tests.  Pick a random 1% of them to
  // execute.
  const bool execute_module = absl::Bernoulli(rand_, 0.01);
  Literal param, expected_val;
  if (execute_module) {
    // Fail this test instance (rather than aborting the whole test binary) if
    // the fake input cannot be created.
    TF_ASSERT_OK_AND_ASSIGN(param, MakeFakeLiteral(src));
    // Run a clone so that `module` itself is left untouched for the pass.
    expected_val = ExecuteNoHloPasses(module->Clone(), {&param});
  }

  BitcastDecomposer pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  // Attach the post-pass module text to any failure reported below.
  SCOPED_TRACE(module->ToString());
  if (!changed) {
    // The pass shouldn't change the bitcast if and only if it's already a
    // reshape-is-bitcast.
    EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(src, dst));
    return;
  }

  // The result must be of the form transpose(bitcast(transpose(param))), except
  // that any of these operations can be skipped.
  const HloInstruction* root = module->entry_computation()
                                   ->root_instruction()
                                   ->fused_instructions_computation()
                                   ->root_instruction();
  // The matchers below capture whichever of these instructions are present;
  // the ones that are absent stay null.
  const HloInstruction* bitcast = nullptr;
  const HloInstruction* transpose1 = nullptr;
  const HloInstruction* transpose2 = nullptr;
  ASSERT_THAT(
      root, ::testing::AnyOf(
                GmockMatch(m::Bitcast(&bitcast, m::Parameter(0))),
                GmockMatch(m::Transpose(&transpose1, m::Parameter(0))),
                GmockMatch(m::Transpose(&transpose1,
                                        m::Bitcast(&bitcast, m::Parameter(0)))),
                GmockMatch(m::Bitcast(
                    &bitcast, m::Transpose(&transpose1, m::Parameter(0)))),
                GmockMatch(m::Transpose(
                    &transpose2,
                    m::Bitcast(&bitcast,
                               m::Transpose(&transpose1, m::Parameter(0)))))));
  if (bitcast != nullptr) {
    EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(bitcast->operand(0)->shape(),
                                            bitcast->shape()));
  }
  if (transpose1 != nullptr) {
    EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose1->operand(0)->shape(),
                                              transpose1->shape(),
                                              transpose1->dimensions()));
  }
  if (transpose2 != nullptr) {
    EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose2->operand(0)->shape(),
                                              transpose2->shape(),
                                              transpose2->dimensions()));
  }

  if (execute_module) {
    // The decomposed module must compute the same value as the original.
    auto actual_val = ExecuteNoHloPasses(module->Clone(), {&param});
    EXPECT_TRUE(LiteralTestUtil::Equal(expected_val, actual_val));
  }
}
208
209 } // anonymous namespace
210 } // namespace xla
211