1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31
32 namespace op = xla::testing::opcode_matchers;
33
34 namespace xla {
35 namespace {
36
37 class TupleSimplifierTest : public HloTestBase {
38 protected:
Run(HloModule * module,bool change_expected)39 void Run(HloModule* module, bool change_expected) {
40 auto changed_status = RunHloPass(TupleSimplifier(), module);
41 TF_ASSERT_OK(changed_status.status());
42 EXPECT_EQ(change_expected, changed_status.ValueOrDie());
43 }
Run(HloModule * module,bool change_expected,bool exclude_entry)44 void Run(HloModule* module, bool change_expected, bool exclude_entry) {
45 auto changed_status = RunHloPass(TupleSimplifier(exclude_entry), module);
46 TF_ASSERT_OK(changed_status.status());
47 EXPECT_EQ(change_expected, changed_status.ValueOrDie());
48 }
49
50 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
51 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
52 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
53 ShapeUtil::MakeShape(F32, {})});
54 };
55
TEST_F(TupleSimplifierTest,TupleOfParameters)56 TEST_F(TupleSimplifierTest, TupleOfParameters) {
57 // A Tuple constructed of a bunch of parameters should not be changed.
58 HloComputation::Builder builder(TestName());
59 HloInstruction* param0 = builder.AddInstruction(
60 HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
61 HloInstruction* param1 = builder.AddInstruction(
62 HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
63 HloInstruction* param2 = builder.AddInstruction(
64 HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
65 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2}));
66 auto module = CreateNewVerifiedModule();
67 module->AddEntryComputation(builder.Build());
68
69 Run(module.get(), /*change_expected=*/false);
70 }
71
TEST_F(TupleSimplifierTest,GteOfTupleOfParameter)72 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
73 // A GTE of a tuple parameter should not be changed.
74 HloComputation::Builder builder(TestName());
75 HloInstruction* param = builder.AddInstruction(
76 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
77 builder.AddInstruction(
78 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
79 auto module = CreateNewVerifiedModule();
80 module->AddEntryComputation(builder.Build());
81
82 Run(module.get(), /*change_expected=*/false);
83 }
84
TEST_F(TupleSimplifierTest,GteOfTuple)85 TEST_F(TupleSimplifierTest, GteOfTuple) {
86 // A GTE of a Tuple should be short-circuited.
87 HloComputation::Builder builder(TestName());
88 HloInstruction* param0 = builder.AddInstruction(
89 HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
90 HloInstruction* param1 = builder.AddInstruction(
91 HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
92 HloInstruction* param2 = builder.AddInstruction(
93 HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
94 HloInstruction* tuple = builder.AddInstruction(
95 HloInstruction::CreateTuple({param0, param1, param2}));
96 HloInstruction* gte = builder.AddInstruction(
97 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
98
99 auto module = CreateNewVerifiedModule();
100 auto computation = module->AddEntryComputation(builder.Build());
101
102 EXPECT_THAT(computation->root_instruction(), gte);
103
104 Run(module.get(), /*change_expected=*/true);
105
106 EXPECT_THAT(computation->root_instruction(), param1);
107 }
108
TEST_F(TupleSimplifierTest,GteOfTupleChain)109 TEST_F(TupleSimplifierTest, GteOfTupleChain) {
110 // Verify a chain of GTE/Tuple instructions is collapsed.
111 HloComputation::Builder builder(TestName());
112 HloInstruction* param = builder.AddInstruction(
113 HloInstruction::CreateParameter(0, scalar_shape_, "param"));
114
115 const int kChainLength = 10;
116 HloInstruction* element = param;
117 for (int i = 0; i < kChainLength; ++i) {
118 HloInstruction* tuple = builder.AddInstruction(
119 HloInstruction::CreateTuple({element, element, element}));
120 element = builder.AddInstruction(
121 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
122 }
123 builder.AddInstruction(
124 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element));
125
126 auto module = CreateNewVerifiedModule();
127 auto computation = module->AddEntryComputation(builder.Build());
128
129 EXPECT_THAT(computation->root_instruction(),
130 op::Negate(op::GetTupleElement(op::Tuple())));
131
132 Run(module.get(), /*change_expected=*/true);
133
134 EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
135 }
136
TEST_F(TupleSimplifierTest,NestedGteOfTuples)137 TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
138 // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested
139 // to some depth with a chain of Tuple instructions, then extracted with a
140 // chain of GTE instructions.
141 HloComputation::Builder builder(TestName());
142 HloInstruction* param = builder.AddInstruction(
143 HloInstruction::CreateParameter(0, scalar_shape_, "param"));
144
145 const int kNestingDepth = 5;
146 HloInstruction* nested_tuple = param;
147 for (int i = 0; i < kNestingDepth; ++i) {
148 nested_tuple = builder.AddInstruction(
149 HloInstruction::CreateTuple({nested_tuple, nested_tuple}));
150 }
151
152 HloInstruction* element = nested_tuple;
153 for (int i = 0; i < kNestingDepth; ++i) {
154 element = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
155 ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0));
156 }
157
158 auto module = CreateNewVerifiedModule();
159 auto computation = module->AddEntryComputation(builder.Build());
160
161 EXPECT_THAT(computation->root_instruction(), element);
162
163 Run(module.get(), /*change_expected=*/true);
164
165 EXPECT_THAT(computation->root_instruction(), param);
166 }
167
TEST_F(TupleSimplifierTest,TupleOfGteInstructions)168 TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
169 // Verify that a tuple constructed of GTE instructions operating on the same
170 // tuple are collapsed.
171 HloComputation::Builder builder(TestName());
172 HloInstruction* tuple_param = builder.AddInstruction(
173 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
174 HloInstruction* gte0 = builder.AddInstruction(
175 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
176 HloInstruction* gte1 = builder.AddInstruction(
177 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
178 HloInstruction* gte2 = builder.AddInstruction(
179 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2));
180 HloInstruction* tuple =
181 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
182
183 auto module = CreateNewVerifiedModule();
184 auto computation = module->AddEntryComputation(builder.Build());
185
186 EXPECT_THAT(computation->root_instruction(), tuple);
187
188 Run(module.get(), /*change_expected=*/true);
189
190 EXPECT_THAT(computation->root_instruction(), tuple_param);
191 }
192
TEST_F(TupleSimplifierTest,IncompatibleTuples)193 TEST_F(TupleSimplifierTest, IncompatibleTuples) {
194 // Verify that a tuple->GTE->tuple construct is not simplified if the input
195 // and output tuple are not compatible shapes.
196 HloComputation::Builder builder(TestName());
197 HloInstruction* tuple_param = builder.AddInstruction(
198 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
199 HloInstruction* gte0 = builder.AddInstruction(
200 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
201 HloInstruction* gte1 = builder.AddInstruction(
202 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
203 // Output tuple has only two elements. Parameter tuple has three elements so
204 // simplification is not possible.
205 HloInstruction* tuple =
206 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
207
208 auto module = CreateNewVerifiedModule();
209 auto computation = module->AddEntryComputation(builder.Build());
210
211 EXPECT_THAT(computation->root_instruction(), tuple);
212
213 Run(module.get(), /*change_expected=*/false);
214
215 EXPECT_THAT(computation->root_instruction(), tuple);
216 }
217
TEST_F(TupleSimplifierTest,CanExcludeEntryComputation)218 TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
219 // Verify that the root computation can be excluded
220 auto module = CreateNewVerifiedModule();
221
222 HloInstruction* p0;
223 HloInstruction* p1;
224 HloComputation* c0;
225 HloComputation* c1;
226 HloComputation* entry;
227
228 {
229 HloComputation::Builder builder(TestName() + "_1");
230 p0 = builder.AddInstruction(
231 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
232 HloInstruction* gte0 = builder.AddInstruction(
233 HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
234 HloInstruction* gte1 = builder.AddInstruction(
235 HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
236 HloInstruction* gte2 = builder.AddInstruction(
237 HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));
238
239 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
240
241 c0 = module->AddEmbeddedComputation(builder.Build());
242 }
243 {
244 HloComputation::Builder builder(TestName() + "_2");
245 p1 = builder.AddInstruction(
246 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
247 HloInstruction* gte0 = builder.AddInstruction(
248 HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
249 HloInstruction* gte1 = builder.AddInstruction(
250 HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
251 HloInstruction* gte2 = builder.AddInstruction(
252 HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));
253
254 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
255
256 c1 = module->AddEmbeddedComputation(builder.Build());
257 }
258 {
259 HloComputation::Builder builder(TestName() + "_Entry");
260 HloInstruction* tuple_param = builder.AddInstruction(
261 HloInstruction::CreateParameter(0, tuple_shape_, "param"));
262 HloInstruction* call0 = builder.AddInstruction(
263 HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
264 HloInstruction* call1 = builder.AddInstruction(
265 HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
266 HloInstruction* gte0 = builder.AddInstruction(
267 HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
268 HloInstruction* gte1 = builder.AddInstruction(
269 HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
270 HloInstruction* tuple0 =
271 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
272 HloInstruction* gte2 = builder.AddInstruction(
273 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
274 HloInstruction* gte3 = builder.AddInstruction(
275 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));
276
277 builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));
278
279 entry = module->AddEntryComputation(builder.Build());
280 }
281
282 Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true);
283
284 EXPECT_THAT(c0->root_instruction(), p0);
285 EXPECT_THAT(c1->root_instruction(), p1);
286 EXPECT_THAT(entry->instruction_count(), 9);
287 }
288
TEST_F(TupleSimplifierTest,ShardingLoss)289 TEST_F(TupleSimplifierTest, ShardingLoss) {
290 const char* kModuleStr = R"(
291 HloModule m
292
293 ENTRY test {
294 p0 = s32[10] parameter(0), sharding={devices=[2]0,1}
295 t = (s32[10]) tuple(p0)
296 ROOT %gte = s32[10] get-tuple-element(t), index=0, sharding={replicated}
297 }
298 )";
299 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
300 Run(m.get(), /*change_expected=*/false);
301 }
302
303 } // namespace
304 } // namespace xla
305