/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gather_expander.h"

#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

using GatherExpanderTest = HloTestBase;

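// Expanding a gather with more than 2147483647 (INT32_MAX) gather indices is
// not supported; the expander should surface an UNIMPLEMENTED error.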
TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
  const std::string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2147483647,5] parameter(1)
  ROOT gather = s32[2147483647,3,5] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  Status status = GatherExpander{GatherExpander::kEliminateAllGathers}
                      .Run(module.get())
                      .status();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);

  ASSERT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
                           "indices are not supported."));
}

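// The while loop that replaces the gather should not carry loop-state tensors
// with degenerate (size-1) dimensions.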
TEST_F(GatherExpanderTest, AvoidDegenerateDims) {
  const std::string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
  ASSERT_TRUE(changed);

  HloInstruction* while_instr = nullptr;
  for (auto* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kWhile) {
      ASSERT_EQ(while_instr, nullptr)
          << "Expected exactly one while instruction in the entry computation "
             "after gather expansion";
      while_instr = instr;
    }
  }

  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";

  // We want to avoid creating a while loop with shapes that have degenerate
  // dimensions for TF gather. In this case we expect the loop state to be of
  // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}). The leading
  // sNN is an implementation detail of WhileUtil::MakeCountedLoop, so we don't
  // check it here (though in theory the form of the while loop state is itself
  // an implementation detail of WhileUtil::MakeCountedLoop).

  const Shape& while_shape = while_instr->shape();
  ASSERT_TRUE(while_shape.IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4);

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {3, 3}),
      ShapeUtil::GetTupleElementShape(while_shape, 1)));

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {2}),
      ShapeUtil::GetTupleElementShape(while_shape, 2)));

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {2, 3}),
      ShapeUtil::GetTupleElementShape(while_shape, 3)));
}

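// The OpMetadata attached to the original gather should be propagated to the
// while loop that replaces it.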
TEST_F(GatherExpanderTest, CheckOpMetadata) {
  const std::string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  OpMetadata metadata;
  metadata.set_op_name("Gather");
  module->entry_computation()->root_instruction()->set_metadata(metadata);
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
  ASSERT_TRUE(changed);

  HloInstruction* while_instr = nullptr;
  for (auto* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kWhile) {
      ASSERT_EQ(while_instr, nullptr)
          << "Expected exactly one while instruction in the entry computation "
             "after gather expansion";
      while_instr = instr;
    }
  }

  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";
  EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
}

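// A gather that reads more than one slice is not "simple", so
// kEliminateSimpleGathers must leave it untouched.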
TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
  const std::string hlo_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_FALSE(changed);
}

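// A gather with a single gather index is trivial; after expansion no kGather
// instruction should remain in the entry computation.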
TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
  const std::string hlo_text = R"(
HloModule test

ENTRY main {
  operand = s32[100] parameter(0)
  indices = s32[1] parameter(1)
  ROOT gather = s32[10] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={},
      start_index_map={0},
      index_vector_dim=0,
      slice_sizes={10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander pass(GatherExpander::kEliminateAllGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_TRUE(changed);
  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
                                                  {HloOpcode::kGather}));
}

}  // namespace
}  // namespace xla