• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
17 
18 #include <algorithm>
19 #include <memory>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_utils.h"
28 #include "tensorflow/compiler/xla/types.h"
29 
30 namespace xla {
31 namespace gpu {
32 
33 class GpuHloScheduleTest : public HloTestBase {
34  protected:
35   using HloVec = std::vector<HloInstruction*>;
36 
37   // Pre-canned shapes.
38   Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
39 
BuildHloOrdering(HloModule * module)40   static SequentialHloOrdering BuildHloOrdering(HloModule* module) {
41     HloSchedule schedule =
42         ScheduleGpuModule(module, /*pointer_size=*/8).value();
43     return SequentialHloOrdering{schedule};
44   }
45 
CreateNewVerifiedModule()46   std::unique_ptr<HloModule> CreateNewVerifiedModule() {
47     HloModuleConfig config;
48     auto debug_options = GetDebugOptionsForTest();
49     config.set_debug_options(debug_options);
50     return std::make_unique<HloModule>("test_module", config);
51   }
52 };
53 
54 // Test of a single stream, where data dependencies fully determine the
55 // execution order.
TEST_F(GpuHloScheduleTest,SequentialMatMul)56 TEST_F(GpuHloScheduleTest, SequentialMatMul) {
57   HloComputation::Builder builder("entry_computation");
58   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
59       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
60   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
61       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
62   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
63       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
64   HloInstruction* dot1 =
65       builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
66   HloInstruction* dot2 =
67       builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
68 
69   auto module = CreateNewVerifiedModule();
70   module->AddEntryComputation(builder.Build(dot2));
71 
72   SequentialHloOrdering order = BuildHloOrdering(module.get());
73   EXPECT_TRUE(order.ExecutesBefore(y, x));
74   EXPECT_TRUE(order.ExecutesBefore(y, dot1));
75   EXPECT_TRUE(order.ExecutesBefore(z, dot1));
76   EXPECT_TRUE(order.ExecutesBefore(z, dot2));
77   EXPECT_TRUE(order.ExecutesBefore(dot1, dot2));
78 }
79 
80 // Test of a single stream, where data dependencies do not fully determine the
81 // execution order, but the stream assignment does.
TEST_F(GpuHloScheduleTest,SequentialAdd)82 TEST_F(GpuHloScheduleTest, SequentialAdd) {
83   HloComputation::Builder builder("entry_computation");
84   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
85       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
86   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
87       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
88   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
89       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
90   HloInstruction* add1 = builder.AddInstruction(
91       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
92   HloInstruction* add2 = builder.AddInstruction(
93       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
94   HloInstruction* add3 = builder.AddInstruction(
95       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
96 
97   auto module = CreateNewVerifiedModule();
98   module->AddEntryComputation(builder.Build(add3));
99 
100   SequentialHloOrdering order = BuildHloOrdering(module.get());
101   EXPECT_TRUE(order.ExecutesBefore(y, x));
102   EXPECT_TRUE(order.ExecutesBefore(y, add1));
103   EXPECT_TRUE(order.ExecutesBefore(z, add1));
104   EXPECT_TRUE(order.ExecutesBefore(z, add2));
105   EXPECT_TRUE(order.ExecutesBefore(add1, add2));
106   EXPECT_TRUE(order.ExecutesBefore(add2, add3));
107 }
108 
TEST_F(GpuHloScheduleTest,AsyncCustomCall)109 TEST_F(GpuHloScheduleTest, AsyncCustomCall) {
110   HloComputation::Builder builder("entry_computation");
111   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
112       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
113   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
114       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
115   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
116       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
117   HloInstruction* add0 = builder.AddInstruction(
118       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
119   HloInstruction* add1 = builder.AddInstruction(
120       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add0, y));
121   HloInstruction* add2 = builder.AddInstruction(
122       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, z));
123   // Create nonblocking_call(add0).
124   HloInstruction* nonblocking_call =
125       builder.AddInstruction(HloInstruction::CreateCustomCall(
126           f32_2x2_, {add0},
127           /*custom_call_target=*/"nonblocking-call-start",
128           /*opaque=*/""));
129   static_cast<HloCustomCallInstruction*>(nonblocking_call)
130       ->set_custom_call_schedule(SCHEDULE_EARLIEST);
131   // In addition, add control_dependency: add1->nonblocking_call.
132   TF_CHECK_OK(add1->AddControlDependencyTo(nonblocking_call));
133   // Blocking call, which only add4 depends on.
134   HloInstruction* blocking_call =
135       builder.AddInstruction(HloInstruction::CreateCustomCall(
136           f32_2x2_, {nonblocking_call},
137           /*custom_call_target=*/"blocking-call-done",
138           /*opaque=*/""));
139   static_cast<HloCustomCallInstruction*>(blocking_call)
140       ->set_custom_call_schedule(SCHEDULE_LATEST);
141   HloInstruction* add3 = builder.AddInstruction(
142       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
143   HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
144       f32_2x2_, HloOpcode::kAdd, add3, blocking_call));
145 
146   auto module = CreateNewVerifiedModule();
147   module->AddEntryComputation(builder.Build(add4));
148 
149   SequentialHloOrdering order = BuildHloOrdering(module.get());
150   VLOG(2) << order.ToString();
151 
152   // Order constrained by data dependency.
153   EXPECT_TRUE(order.ExecutesBefore(add0, nonblocking_call));
154   // Order constrained by control dependency.
155   EXPECT_TRUE(order.ExecutesBefore(add1, nonblocking_call));
156   // Test that nonblocking_call is scheduled before add2, so that we know
157   // EARLIEST is in effect.
158   EXPECT_TRUE(order.ExecutesBefore(nonblocking_call, add2));
159   EXPECT_TRUE(order.ExecutesBefore(nonblocking_call, add3));
160   EXPECT_TRUE(order.ExecutesBefore(nonblocking_call, add4));
161 
162   // Test that blocking_call is scheduled after add3, so that we know
163   // LATEST is in effect.
164   EXPECT_TRUE(order.ExecutesBefore(add3, blocking_call));
165   EXPECT_TRUE(order.ExecutesBefore(blocking_call, add4));
166 }
167 
TEST_F(GpuHloScheduleTest,AsyncAllReduce)168 TEST_F(GpuHloScheduleTest, AsyncAllReduce) {
169   // All-reduce reduction computation.
170   HloComputation::Builder reduction_builder("add");
171   HloInstruction* x0 =
172       reduction_builder.AddInstruction(HloInstruction::CreateParameter(
173           /*parameter_number=*/0, ShapeUtil::MakeScalarShape(F32),
174           /*name=*/"x"));
175   HloInstruction* y0 =
176       reduction_builder.AddInstruction(HloInstruction::CreateParameter(
177           /*parameter_number=*/1, ShapeUtil::MakeScalarShape(F32),
178           /*name=*/"y"));
179   HloInstruction* add =
180       reduction_builder.AddInstruction(HloInstruction::CreateBinary(
181           ShapeUtil::MakeScalarShape(F32), HloOpcode::kAdd, x0, y0));
182 
183   std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
184   HloComputation* reduction_computation =
185       module->AddEmbeddedComputation(reduction_builder.Build(add));
186 
187   HloComputation::Builder builder("entry_computation");
188   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
189       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
190   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
191       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
192   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
193       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
194   HloInstruction* add0 = builder.AddInstruction(
195       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
196   HloInstruction* add1 = builder.AddInstruction(
197       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add0, y));
198   HloInstruction* add2 = builder.AddInstruction(
199       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, z));
200 
201   Shape all_reduce_start_shape =
202       ShapeUtil::MakeTupleShape({f32_2x2_, f32_2x2_});
203   HloInstruction* all_reduce_start =
204       builder.AddInstruction(HloInstruction::CreateAllReduceStart(
205           all_reduce_start_shape, {add0}, reduction_computation,
206           /*replica_groups=*/{}, /*constrain_layout=*/false,
207           /*channel_id=*/std::nullopt, /*use_global_device_ids=*/true));
208   // In addition, add control_dependency: add1->nonblocking_call.
209   TF_CHECK_OK(add1->AddControlDependencyTo(all_reduce_start));
210   // Blocking call, which only add4 depends on.
211   HloInstruction* all_reduce_done =
212       builder.AddInstruction(HloInstruction::CreateUnary(
213           f32_2x2_, HloOpcode::kAllReduceDone, all_reduce_start));
214   HloInstruction* add3 = builder.AddInstruction(
215       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
216   HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
217       f32_2x2_, HloOpcode::kAdd, add3, all_reduce_done));
218 
219   module->AddEntryComputation(builder.Build(add4));
220 
221   SequentialHloOrdering order = BuildHloOrdering(module.get());
222   VLOG(2) << order.ToString();
223 
224   // Order constrained by data dependency.
225   EXPECT_TRUE(order.ExecutesBefore(add0, all_reduce_start));
226   // Order constrained by control dependency.
227   EXPECT_TRUE(order.ExecutesBefore(add1, all_reduce_start));
228   // Test that all_reduce_start is scheduled before add2.
229   EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add2));
230   EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add3));
231   EXPECT_TRUE(order.ExecutesBefore(all_reduce_start, add4));
232 
233   // Test that all_reduce_done is scheduled after add3.
234   EXPECT_TRUE(order.ExecutesBefore(add3, all_reduce_done));
235   EXPECT_TRUE(order.ExecutesBefore(all_reduce_done, add4));
236 }
237 
238 }  // namespace gpu
239 }  // namespace xla
240