1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
17
18 #include <algorithm>
19 #include <memory>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_utils.h"
28 #include "tensorflow/compiler/xla/types.h"
29
30 namespace xla {
31 namespace gpu {
32
33 class GpuHloScheduleTest : public HloTestBase {
34 protected:
35 using HloVec = std::vector<HloInstruction*>;
36
37 // Pre-canned shapes.
38 Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
39
BuildHloOrdering(HloModule * module)40 static SequentialHloOrdering BuildHloOrdering(HloModule* module) {
41 HloSchedule schedule =
42 ScheduleGpuModule(module, /*pointer_size=*/8).value();
43 return SequentialHloOrdering{schedule};
44 }
45
CreateNewVerifiedModule()46 std::unique_ptr<HloModule> CreateNewVerifiedModule() {
47 HloModuleConfig config;
48 auto debug_options = GetDebugOptionsForTest();
49 config.set_debug_options(debug_options);
50 return std::make_unique<HloModule>("test_module", config);
51 }
52 };
53
54 // Test of a single stream, where data dependencies fully determine the
55 // execution order.
TEST_F(GpuHloScheduleTest,SequentialMatMul)56 TEST_F(GpuHloScheduleTest, SequentialMatMul) {
57 HloComputation::Builder builder("entry_computation");
58 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
59 /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
60 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
61 /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
62 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
63 /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
64 HloInstruction* dot1 =
65 builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
66 HloInstruction* dot2 =
67 builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
68
69 auto module = CreateNewVerifiedModule();
70 module->AddEntryComputation(builder.Build(dot2));
71
72 SequentialHloOrdering order = BuildHloOrdering(module.get());
73 EXPECT_TRUE(order.ExecutesBefore(y, x));
74 EXPECT_TRUE(order.ExecutesBefore(y, dot1));
75 EXPECT_TRUE(order.ExecutesBefore(z, dot1));
76 EXPECT_TRUE(order.ExecutesBefore(z, dot2));
77 EXPECT_TRUE(order.ExecutesBefore(dot1, dot2));
78 }
79
80 // Test of a single stream, where data dependencies do not fully determine the
81 // execution order, but the stream assignment does.
TEST_F(GpuHloScheduleTest,SequentialAdd)82 TEST_F(GpuHloScheduleTest, SequentialAdd) {
83 HloComputation::Builder builder("entry_computation");
84 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
85 /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
86 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
87 /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
88 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
89 /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
90 HloInstruction* add1 = builder.AddInstruction(
91 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
92 HloInstruction* add2 = builder.AddInstruction(
93 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
94 HloInstruction* add3 = builder.AddInstruction(
95 HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
96
97 auto module = CreateNewVerifiedModule();
98 module->AddEntryComputation(builder.Build(add3));
99
100 SequentialHloOrdering order = BuildHloOrdering(module.get());
101 EXPECT_TRUE(order.ExecutesBefore(y, x));
102 EXPECT_TRUE(order.ExecutesBefore(y, add1));
103 EXPECT_TRUE(order.ExecutesBefore(z, add1));
104 EXPECT_TRUE(order.ExecutesBefore(z, add2));
105 EXPECT_TRUE(order.ExecutesBefore(add1, add2));
106 EXPECT_TRUE(order.ExecutesBefore(add2, add3));
107 }
108
// Checks that custom-call scheduling hints are honoured.
// - The "nonblocking-call-start" call is marked SCHEDULE_EARLIEST and must run
//   once its data (add0) and control (add1) predecessors are done, ahead of
//   add2/add3/add4.
// - The "blocking-call-done" call is marked SCHEDULE_LATEST and must be
//   deferred past add3 while still running before its only user, add4.
TEST_F(GpuHloScheduleTest, AsyncCustomCall) {
  HloComputation::Builder builder("entry_computation");
  // Appends an f32[2,2] parameter with the given number and name.
  auto add_parameter = [&](int number, const char* name) {
    return builder.AddInstruction(
        HloInstruction::CreateParameter(number, f32_2x2_, name));
  };
  // Appends an elementwise f32[2,2] add of `lhs` and `rhs`.
  auto add_sum = [&](HloInstruction* lhs, HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, lhs, rhs));
  };

  HloInstruction* x = add_parameter(0, "x");
  HloInstruction* y = add_parameter(1, "y");
  HloInstruction* z = add_parameter(2, "z");
  HloInstruction* add0 = add_sum(x, y);
  HloInstruction* add1 = add_sum(add0, y);
  HloInstruction* add2 = add_sum(add1, z);

  // Start of the async pair: consumes add0 and asks to be scheduled early.
  HloInstruction* nonblocking_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          f32_2x2_, {add0},
          /*custom_call_target=*/"nonblocking-call-start",
          /*opaque=*/""));
  static_cast<HloCustomCallInstruction*>(nonblocking_call)
      ->set_custom_call_schedule(SCHEDULE_EARLIEST);
  // Extra control edge add1 -> nonblocking_call, on top of the data edge.
  TF_CHECK_OK(add1->AddControlDependencyTo(nonblocking_call));

  // End of the async pair: asks to be scheduled late; only add4 consumes it.
  HloInstruction* blocking_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          f32_2x2_, {nonblocking_call},
          /*custom_call_target=*/"blocking-call-done",
          /*opaque=*/""));
  static_cast<HloCustomCallInstruction*>(blocking_call)
      ->set_custom_call_schedule(SCHEDULE_LATEST);

  HloInstruction* add3 = add_sum(add1, add2);
  HloInstruction* add4 = add_sum(add3, blocking_call);

  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build(add4));

  SequentialHloOrdering ordering = BuildHloOrdering(module.get());
  VLOG(2) << ordering.ToString();

  // Required by the data edge and by the explicit control edge.
  EXPECT_TRUE(ordering.ExecutesBefore(add0, nonblocking_call));
  EXPECT_TRUE(ordering.ExecutesBefore(add1, nonblocking_call));

  // SCHEDULE_EARLIEST: the start is hoisted above everything it doesn't need.
  EXPECT_TRUE(ordering.ExecutesBefore(nonblocking_call, add2));
  EXPECT_TRUE(ordering.ExecutesBefore(nonblocking_call, add3));
  EXPECT_TRUE(ordering.ExecutesBefore(nonblocking_call, add4));

  // SCHEDULE_LATEST: the done is sunk below add3, just ahead of its user.
  EXPECT_TRUE(ordering.ExecutesBefore(add3, blocking_call));
  EXPECT_TRUE(ordering.ExecutesBefore(blocking_call, add4));
}
167
// Checks that an async all-reduce pair is spread apart by the scheduler.
// - all-reduce-start runs once its data (add0) and control (add1)
//   predecessors are done, ahead of add2/add3/add4.
// - all-reduce-done is deferred past add3 while still running before its only
//   user, add4.
TEST_F(GpuHloScheduleTest, AsyncAllReduce) {
  // Scalar f32 add, used as the all-reduce reduction computation.
  HloComputation::Builder reduction_builder("add");
  HloInstruction* x0 =
      reduction_builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, ShapeUtil::MakeScalarShape(F32),
          /*name=*/"x"));
  HloInstruction* y0 =
      reduction_builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, ShapeUtil::MakeScalarShape(F32),
          /*name=*/"y"));
  HloInstruction* add =
      reduction_builder.AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(F32), HloOpcode::kAdd, x0, y0));

  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  HloComputation* reduction_computation =
      module->AddEmbeddedComputation(reduction_builder.Build(add));

  HloComputation::Builder builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
  HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
  HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add0, y));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, z));

  Shape all_reduce_start_shape =
      ShapeUtil::MakeTupleShape({f32_2x2_, f32_2x2_});
  // No channel_id is set, so this is a cross-replica all-reduce. Global device
  // ids are only meaningful together with a channel_id, so
  // use_global_device_ids stays false to keep the HLO well-formed. (The
  // fixture's module is not verified, so the invalid combination
  // previously went unnoticed.)
  HloInstruction* all_reduce_start =
      builder.AddInstruction(HloInstruction::CreateAllReduceStart(
          all_reduce_start_shape, {add0}, reduction_computation,
          /*replica_groups=*/{}, /*constrain_layout=*/false,
          /*channel_id=*/std::nullopt, /*use_global_device_ids=*/false));
  // In addition, add control_dependency: add1->all_reduce_start.
  TF_CHECK_OK(add1->AddControlDependencyTo(all_reduce_start));
  // All-reduce-done completes the async pair; only add4 depends on it.
  HloInstruction* all_reduce_done =
      builder.AddInstruction(HloInstruction::CreateUnary(
          f32_2x2_, HloOpcode::kAllReduceDone, all_reduce_start));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
  HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_2x2_, HloOpcode::kAdd, add3, all_reduce_done));

  module->AddEntryComputation(builder.Build(add4));

  SequentialHloOrdering order = BuildHloOrdering(module.get());
  VLOG(2) << order.ToString();

  // Order constrained by data dependency.
  EXPECT_TRUE(order.ExecutesBefore(add0, all_reduce_start));
  // Order constrained by control dependency.
  EXPECT_TRUE(order.ExecutesBefore(add1, all_reduce_start));
  // Test that all_reduce_start is scheduled before add2, i.e. as early as its
  // dependencies allow.
  EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add2));
  EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add3));
  EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add4));

  // Test that all_reduce_done is scheduled after add3, i.e. as late as its
  // only user allows.
  EXPECT_TRUE(order.ExecutesBefore(add3, all_reduce_done));
  EXPECT_TRUE(order.ExecutesBefore(all_reduce_done, add4));
}
237
238 } // namespace gpu
239 } // namespace xla
240