1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
17
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_dce.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34
35 namespace xla {
36 namespace {
37
38 class HloScheduleTest : public HloTestBase {};
39
// Updating the schedule of an unchanged HLO module should not affect the
// schedule at all: the entry sequence observed before Update() must equal the
// one observed afterwards.
TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
  const std::string module_str = R"(
HloModule UpdateScheduleUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

  // Take the "before" snapshot BY VALUE. Binding a reference here would alias
  // the schedule's own storage, so the comparison after Update() would compare
  // the updated sequence with itself and could never fail.
  const std::vector<HloInstruction*> entry_schedule =
      schedule.sequence(module->entry_computation()).instructions();

  // The module has six instructions, all of which must be scheduled.
  EXPECT_EQ(entry_schedule.size(), 6);

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Same instructions, same order as before the update.
  EXPECT_EQ(entry_schedule,
            schedule.sequence(module->entry_computation()).instructions());
}
73
// Appends new instructions to an already scheduled module and checks that the
// stale schedule fails verification until Update() brings it up to date.
TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithNewInstructions

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Buffer sizes are simply the byte size of each buffer's shape.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size_of));

  // Build `constant - old_root` and make it the new root of the entry
  // computation. Neither new instruction is known to the schedule yet.
  HloComputation* entry = module->entry_computation();
  HloInstruction* old_root = entry->root_instruction();
  const Shape result_shape = old_root->shape();
  HloInstruction* constant = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
      result_shape, HloOpcode::kSubtract, constant, old_root));
  entry->set_root_instruction(sub);

  // Reports whether `hlo` currently appears in the entry sequence.
  auto is_scheduled = [&](const HloInstruction* hlo) {
    const auto& scheduled = schedule.sequence(entry).instructions();
    return absl::c_linear_search(scheduled, hlo);
  };

  // Before the update: only the six original instructions are scheduled.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);
  EXPECT_FALSE(is_scheduled(constant));
  EXPECT_FALSE(is_scheduled(sub));

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // After the update: the two additions are part of the sequence.
  EXPECT_EQ(schedule.sequence(entry).size(), 8);
  EXPECT_TRUE(is_scheduled(constant));
  EXPECT_TRUE(is_scheduled(sub));
}
121
// Both adds and removes instructions in a scheduled module and checks that
// Update() repairs the schedule.
TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithAddedAndDeletedInstruction

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Buffer sizes are simply the byte size of each buffer's shape.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size_of));

  // Make the entry root `constant - parameter(0)`, an expression that uses
  // nothing but a fresh constant and one parameter.
  HloComputation* entry = module->entry_computation();
  HloInstruction* constant = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* first_param = entry->parameter_instruction(0);
  HloInstruction* new_root =
      entry->AddInstruction(HloInstruction::CreateBinary(
          constant->shape(), HloOpcode::kSubtract, constant, first_param));
  entry->set_root_instruction(new_root);

  // After DCE only the parameters and the newly added code should remain.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // The schedule still holds the six original entries and is now stale.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(schedule.sequence(entry).size(), 4);
}
168
// Swaps every instruction of the entry computation for a brand new set and
// checks that Update() can still bring the schedule back in sync.
TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithCompletelyReplacedModule

ENTRY main {
  a = f32[] constant(42.0)
  b = f32[] constant(123.0)
  ROOT sum = f32[] add(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Buffer sizes are simply the byte size of each buffer's shape.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size_of));

  // The new entry computation is just negate(constant(1.0)).
  HloComputation* entry = module->entry_computation();
  HloInstruction* constant = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* negation = entry->AddInstruction(HloInstruction::CreateUnary(
      constant->shape(), HloOpcode::kNegate, constant));
  entry->set_root_instruction(negation);

  // Sweep away the now-unreachable original instructions.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // The schedule still lists the three original instructions.
  EXPECT_EQ(schedule.sequence(entry).size(), 3);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(schedule.sequence(entry).size(), 2);
}
210
// Modifies two different computations (the while condition and the while
// body) of one module and checks that a single Update() fixes both sequences.
TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Tuple shapes are involved, so the size function is given a pointer size.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size_of));

  // The while instruction is the sole operand of the entry root.
  const HloInstruction* xla_while =
      module->entry_computation()->root_instruction()->operand(0);
  HloComputation* body = xla_while->while_body();
  HloComputation* cond = xla_while->while_condition();

  // Condition: wrap the current root in a logical NOT and make that the root.
  HloInstruction* old_cond_root = cond->root_instruction();
  HloInstruction* negated_cond =
      cond->AddInstruction(HloInstruction::CreateUnary(
          ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNot, old_cond_root));
  cond->set_root_instruction(negated_cond);

  // Body: turn it into a pass-through that returns its own parameter.
  body->set_root_instruction(body->parameter_instruction(0));

  // Remove the code in the body that just became dead.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // The sequences still have their pre-modification sizes.
  EXPECT_EQ(schedule.sequence(body).size(), 7);
  EXPECT_EQ(schedule.sequence(cond).size(), 4);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Body shrank to its parameter; the condition gained the NOT.
  EXPECT_EQ(schedule.sequence(body).size(), 1);
  EXPECT_EQ(schedule.sequence(cond).size(), 5);
}
280
// Deletes whole computations from a scheduled module and checks that Update()
// copes with their disappearance.
TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Tuple shapes are involved, so the size function is given a pointer size.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size_of));

  // The while instruction is the sole operand of the entry root, and the
  // while's own operand is its init value.
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  HloInstruction* xla_while = entry_root->mutable_operand(0);
  HloInstruction* init = xla_while->mutable_operand(0);

  // Route every user of the while to the init value instead. That should
  // leave the condition and body computations dead.
  TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));

  // Run DCE, which is expected to drop the two now-unused computations.
  HloDCE dce;
  ASSERT_EQ(module->computation_count(), 3);
  TF_ASSERT_OK(dce.Run(module.get()).status());
  ASSERT_EQ(module->computation_count(), 1);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());
}
339
// Deletes computations that belong to the main execution thread and checks
// that the schedule can be updated for that thread alone, while the
// computation on "parallel_thread" stays unscheduled.
TEST_F(HloScheduleTest, UpdateScheduleComputationRemovedWithMultiThreads) {
  const std::string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

%async_builder {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  ROOT %foo = add(%p0, %p1)
}, execution_thread="parallel_thread"

ENTRY %WhileLoop () -> (s32[], f32[10]) {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  %async-start = ((f32[10], f32[10]), f32[10], s32[]) async-start(f32[10] %p0, f32[10] %p1), async_execution_thread="parallel_thread",calls=%async_builder
  %async-done = f32[10]{0} async-done(((f32[10], f32[10]), f32[10], s32[]) %async-start), async_execution_thread="parallel_thread", calls=%async_builder
  %main_res = s32[] get-tuple-element((s32[], token[]) %while), index=0
  ROOT %res = tuple(%main_res, %async-done)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Tuple shapes are involved, so the size function is given a pointer size.
  auto byte_size_of = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  // Only the main execution thread is passed to the scheduler.
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), byte_size_of,
                     /*algorithm=*/{}, {HloInstruction::kMainExecutionThread}));

  // Entry root is tuple(%main_res, ...); %main_res reads from the while, and
  // the while's own operand is its init value.
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  HloInstruction* main_res = entry_root->mutable_operand(0);
  HloInstruction* xla_while = main_res->mutable_operand(0);
  HloInstruction* init = xla_while->mutable_operand(0);

  // Route every user of the while to the init value instead. That should
  // leave the condition and body computations dead.
  TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));

  // Run DCE, which is expected to drop the two now-unused computations and
  // keep the entry and async computations.
  HloDCE dce;
  ASSERT_EQ(module->computation_count(), 4);
  TF_ASSERT_OK(dce.Run(module.get()).status());
  ASSERT_EQ(module->computation_count(), 2);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update({HloInstruction::kMainExecutionThread}));
  TF_ASSERT_OK(schedule.Verify());

  // Exactly one computation lives on "parallel_thread", and it must not have
  // been given a schedule.
  const auto parallel_computations =
      module->MakeNonfusionComputations({"parallel_thread"});
  ASSERT_EQ(parallel_computations.size(), 1);
  ASSERT_FALSE(
      schedule.is_computation_scheduled(parallel_computations.front()));
}
419 } // namespace
420 } // namespace xla
421