• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bitcast_decomposer.h"
17 
18 #include <string>
19 #include <tuple>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/random/random.h"
24 #include "absl/strings/substitute.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/permutation_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/test_utils.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 
34 namespace xla {
35 namespace {
36 
37 namespace m = ::xla::match;
38 
AllPermutationsOfShape(const Shape & s)39 std::vector<Shape> AllPermutationsOfShape(const Shape& s) {
40   std::vector<int64_t> dims_perm(s.dimensions_size());
41   absl::c_iota(dims_perm, 0);
42   std::vector<int64_t> layout_perm(s.dimensions_size());
43   absl::c_iota(layout_perm, 0);
44 
45   std::vector<Shape> ret;
46   do {
47     do {
48       Shape new_shape = ShapeUtil::MakeShapeWithLayout(
49           s.element_type(),  //
50           ComposePermutations(s.dimensions(), dims_perm),
51           ComposePermutations(s.layout().minor_to_major(), layout_perm));
52       ret.push_back(new_shape);
53     } while (absl::c_next_permutation(layout_perm));
54   } while (absl::c_next_permutation(dims_perm));
55   return ret;
56 }
57 
AllPermutationsOfShapes(std::vector<std::vector<int64_t>> dims_list)58 std::vector<Shape> AllPermutationsOfShapes(
59     std::vector<std::vector<int64_t>> dims_list) {
60   std::vector<Shape> ret;
61   for (const auto& dims : dims_list) {
62     std::vector<Shape> perms = AllPermutationsOfShape(
63         ShapeUtil::MakeShapeWithDescendingLayout(F32, dims));
64     ret.insert(ret.end(), perms.begin(), perms.end());
65   }
66   return ret;
67 }
68 
69 class BitcastDecomposerParameterizedTest
70     : public HloTestBase,
71       public ::testing::WithParamInterface<
72           std::tuple<Shape /*src*/, Shape /*dst*/>> {
73  public:
BitcastDecomposerParameterizedTest()74   BitcastDecomposerParameterizedTest()
75       : HloTestBase(/*verifier_layout_sensitive=*/false,
76                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
77 
78  protected:
79   absl::BitGen rand_;
80 };
81 
82 INSTANTIATE_TEST_SUITE_P(
83     Handcrafted, BitcastDecomposerParameterizedTest,
84     ::testing::Values(std::make_tuple(
85         ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 4, 2, 2}, {2, 4, 3, 1, 0}),
86         ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 2, 2, 4},
87                                        {4, 3, 2, 1, 0}))));
88 
// The combinatorial suite is large. Sanitizer and debug builds skip it and run
// only the handcrafted case; otherwise they would time out.
#if defined(NDEBUG) && !defined(ADDRESS_SANITIZER) && \
    !defined(MEMORY_SANITIZER) && !defined(THREAD_SANITIZER)
INSTANTIATE_TEST_SUITE_P(
    Combinatorial, BitcastDecomposerParameterizedTest,
    ::testing::Combine(
        // Source shapes: descending-layout rank-3 and rank-4 shapes, plus one
        // rank-3 shape with a non-default layout.
        ::testing::ValuesIn(std::vector<Shape>{
            ShapeUtil::MakeShapeWithDescendingLayout(F32, {4, 10, 100}),
            ShapeUtil::MakeShapeWithDescendingLayout(F32, {4, 1, 10, 100}),
            ShapeUtil::MakeShapeWithLayout(F32, {4, 10, 100}, {0, 2, 1}),
        }),
        // Destination shapes: every dimension order and every layout of each
        // dimension list below. All of them hold 4000 elements.
        ::testing::ValuesIn(AllPermutationsOfShapes({
            // The source dims, with no degenerate dims.
            {4, 10, 100},
            // The source dims, with a degenerate dim added.
            {1, 4, 10, 100},
            // Same rank, but elements redistributed between dims.
            {2, 20, 100},
            {40, 10, 10},
            {1, 40, 10, 10},
            // Dims merged, with no redistribution between dims.
            {40, 100},
            {400, 10},
            {4000},
            {1, 4000},
            // Dims merged, with elements redistributed between dims.
            {20, 200},
            // Dims split, with no redistribution between dims.
            {2, 2, 10, 100},
            {4, 2, 5, 100},
            // Dims split, with elements redistributed between dims.
            {2, 5, 5, 80},
        }))));
#endif
125 
TEST_P(BitcastDecomposerParameterizedTest,DoIt)126 TEST_P(BitcastDecomposerParameterizedTest, DoIt) {
127   auto [src, dst] = GetParam();
128 
129   const char* const kModuleTemplate = R"(
130   HloModule module
131 
132   fused_comp {
133     lhs  = $0 parameter(0)
134     ROOT root = $1 bitcast(lhs)
135   }
136 
137   ENTRY main {
138     ROOT fusion = $1 fusion($0 parameter(0)), kind=kLoop, calls=fused_comp
139   })";
140   std::string module_string =
141       absl::Substitute(kModuleTemplate, ShapeUtil::HumanStringWithLayout(src),
142                        ShapeUtil::HumanStringWithLayout(dst));
143 
144   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
145                           ParseAndReturnVerifiedModule(module_string));
146 
147   // Actually compiling and running the module is expensive; we can't afford to
148   // do it for all 9000 (as of writing) tests. Pick a random 1% of them to
149   // execute.
150   bool execute_module = absl::Bernoulli(this->rand_, 0.01);
151   Literal param, expected_val;
152   if (execute_module) {
153     param = MakeFakeLiteral(src).ValueOrDie();
154     expected_val = ExecuteNoHloPasses(module->Clone(), {&param});
155   }
156 
157   BitcastDecomposer pass;
158   TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
159   SCOPED_TRACE(module->ToString());
160   if (!changed) {
161     // The pass shouldn't change the bitcast if and only if it's already a
162     // reshape-is-bitcast.
163     EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(src, dst));
164     return;
165   }
166 
167   // The result must be of the form transpose(bitcast(transpose(param))), except
168   // that any of these operations can be skipped.
169   const HloInstruction* root = module->entry_computation()
170                                    ->root_instruction()
171                                    ->fused_instructions_computation()
172                                    ->root_instruction();
173   const HloInstruction* bitcast = nullptr;
174   const HloInstruction* transpose1 = nullptr;
175   const HloInstruction* transpose2 = nullptr;
176   ASSERT_THAT(
177       root, ::testing::AnyOf(
178                 GmockMatch(m::Bitcast(&bitcast, m::Parameter(0))),
179                 GmockMatch(m::Transpose(&transpose1, m::Parameter(0))),
180                 GmockMatch(m::Transpose(&transpose1,
181                                         m::Bitcast(&bitcast, m::Parameter(0)))),
182                 GmockMatch(m::Bitcast(
183                     &bitcast, m::Transpose(&transpose1, m::Parameter(0)))),
184                 GmockMatch(m::Transpose(
185                     &transpose2,
186                     m::Bitcast(&bitcast,
187                                m::Transpose(&transpose1, m::Parameter(0)))))));
188   if (bitcast != nullptr) {
189     EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(bitcast->operand(0)->shape(),
190                                             bitcast->shape()));
191   }
192   if (transpose1 != nullptr) {
193     EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose1->operand(0)->shape(),
194                                               transpose1->shape(),
195                                               transpose1->dimensions()));
196   }
197   if (transpose2 != nullptr) {
198     EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose2->operand(0)->shape(),
199                                               transpose2->shape(),
200                                               transpose2->dimensions()));
201   }
202 
203   if (execute_module) {
204     auto actual_val = ExecuteNoHloPasses(module->Clone(), {&param});
205     EXPECT_TRUE(LiteralTestUtil::Equal(expected_val, actual_val));
206   }
207 }
208 
209 }  // anonymous namespace
210 }  // namespace xla
211