• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
17 
18 #include <algorithm>
19 #include <unordered_set>
20 
21 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/types.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 class HloScheduleTest : public HloTestBase {
33  protected:
34   using HloVec = std::vector<const HloInstruction*>;
35 
36   // Pre-canned shapes.
37   Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
38 
BuildHloSchedule(const HloModule & module,const StreamAssignment & streams)39   static std::unique_ptr<HloSchedule> BuildHloSchedule(
40       const HloModule& module, const StreamAssignment& streams) {
41     return HloSchedule::Build(module, streams, /*pointer_size=*/8)
42         .ConsumeValueOrDie();
43   }
44 
RemoveHlo(const HloVec & input,const std::unordered_set<const HloInstruction * > & remove)45   HloVec RemoveHlo(const HloVec& input,
46                    const std::unordered_set<const HloInstruction*>& remove) {
47     HloVec result(input);
48     result.erase(std::remove_if(result.begin(), result.end(),
49                                 [&remove](const HloInstruction* x) {
50                                   return remove.count(x) > 0;
51                                 }),
52                  result.end());
53     return result;
54   }
55 };
56 
57 // Test of a single stream, where data dependencies fully determine the
58 // execution order.
TEST_F(HloScheduleTest,SequentialMatMul)59 TEST_F(HloScheduleTest, SequentialMatMul) {
60   HloComputation::Builder builder("entry_computation");
61   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
62       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
63   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
64       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
65   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
66       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
67   HloInstruction* dot1 = builder.AddInstruction(
68       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
69   HloInstruction* dot2 = builder.AddInstruction(
70       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
71 
72   auto module = CreateNewModule();
73   module->AddEntryComputation(builder.Build(dot2));
74 
75   std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
76   EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
77             streams->StreamNumberForHlo(*dot2));
78 
79   auto schedule = BuildHloSchedule(*module, *streams);
80   // Remove parameters, which are unordered.
81   EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
82             HloVec({dot1, dot2}));
83 
84   // Parameters x,y,z are mutually unordered, while dot1 and dot2 are
85   // transitively ordered by operands.
86   auto order = schedule->ConsumeHloOrdering();
87   EXPECT_TRUE(order->ExecutesBefore(x, dot1));
88   EXPECT_TRUE(order->ExecutesBefore(x, dot2));
89   EXPECT_TRUE(order->ExecutesBefore(y, dot1));
90   EXPECT_TRUE(order->ExecutesBefore(y, dot2));
91   EXPECT_TRUE(order->ExecutesBefore(z, dot2));
92   EXPECT_TRUE(order->ExecutesBefore(dot1, dot2));
93 
94   EXPECT_FALSE(order->ExecutesBefore(x, x));
95   EXPECT_FALSE(order->ExecutesBefore(x, y));
96   EXPECT_FALSE(order->ExecutesBefore(x, z));
97   EXPECT_FALSE(order->ExecutesBefore(y, x));
98   EXPECT_FALSE(order->ExecutesBefore(y, y));
99   EXPECT_FALSE(order->ExecutesBefore(y, z));
100   EXPECT_FALSE(order->ExecutesBefore(z, x));
101   EXPECT_FALSE(order->ExecutesBefore(z, y));
102   EXPECT_FALSE(order->ExecutesBefore(z, z));
103   EXPECT_FALSE(order->ExecutesBefore(z, dot1));
104   EXPECT_FALSE(order->ExecutesBefore(dot1, x));
105   EXPECT_FALSE(order->ExecutesBefore(dot1, y));
106   EXPECT_FALSE(order->ExecutesBefore(dot1, z));
107   EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
108   EXPECT_FALSE(order->ExecutesBefore(dot2, x));
109   EXPECT_FALSE(order->ExecutesBefore(dot2, y));
110   EXPECT_FALSE(order->ExecutesBefore(dot2, z));
111   EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
112   EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
113 }
114 
115 // Test of a single stream, where data dependencies do not fully determine the
116 // execution order, but the stream assignment does.
TEST_F(HloScheduleTest,SequentialAdd)117 TEST_F(HloScheduleTest, SequentialAdd) {
118   HloComputation::Builder builder("entry_computation");
119   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
120       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
121   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
122       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
123   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
124       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
125   HloInstruction* add1 = builder.AddInstruction(
126       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
127   HloInstruction* add2 = builder.AddInstruction(
128       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
129   HloInstruction* add3 = builder.AddInstruction(
130       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
131 
132   auto module = CreateNewModule();
133   module->AddEntryComputation(builder.Build(add3));
134 
135   std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
136   EXPECT_EQ(streams->StreamNumberForHlo(*add1),
137             streams->StreamNumberForHlo(*add2));
138   EXPECT_EQ(streams->StreamNumberForHlo(*add1),
139             streams->StreamNumberForHlo(*add3));
140 
141   auto schedule = BuildHloSchedule(*module, *streams);
142   // Remove parameters, which are unordered.
143   EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
144             HloVec({add1, add2, add3}));
145 
146   // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are
147   // transitively ordered by operands.
148   auto order = schedule->ConsumeHloOrdering();
149   EXPECT_TRUE(order->ExecutesBefore(x, add1));
150   EXPECT_TRUE(order->ExecutesBefore(x, add2));
151   EXPECT_TRUE(order->ExecutesBefore(x, add3));
152   EXPECT_TRUE(order->ExecutesBefore(y, add1));
153   EXPECT_TRUE(order->ExecutesBefore(y, add2));
154   EXPECT_TRUE(order->ExecutesBefore(y, add3));
155   EXPECT_TRUE(order->ExecutesBefore(z, add2));
156   EXPECT_TRUE(order->ExecutesBefore(z, add3));
157   EXPECT_TRUE(order->ExecutesBefore(add1, add3));
158   EXPECT_TRUE(order->ExecutesBefore(add2, add3));
159   // The HLO graph does not define an ordering for add1 and add2, but their
160   // assignment onto the same stream does define an ordering.
161   if (order->ExecutesBefore(add1, add2)) {
162     EXPECT_FALSE(order->ExecutesBefore(add2, add1));
163   } else {
164     EXPECT_TRUE(order->ExecutesBefore(add2, add1));
165     EXPECT_FALSE(order->ExecutesBefore(add1, add2));
166   }
167 
168   EXPECT_FALSE(order->ExecutesBefore(x, x));
169   EXPECT_FALSE(order->ExecutesBefore(x, y));
170   EXPECT_FALSE(order->ExecutesBefore(x, z));
171   EXPECT_FALSE(order->ExecutesBefore(y, x));
172   EXPECT_FALSE(order->ExecutesBefore(y, y));
173   EXPECT_FALSE(order->ExecutesBefore(y, z));
174   EXPECT_FALSE(order->ExecutesBefore(z, x));
175   EXPECT_FALSE(order->ExecutesBefore(z, y));
176   EXPECT_FALSE(order->ExecutesBefore(z, z));
177   EXPECT_FALSE(order->ExecutesBefore(z, add1));
178   EXPECT_FALSE(order->ExecutesBefore(add1, x));
179   EXPECT_FALSE(order->ExecutesBefore(add1, y));
180   EXPECT_FALSE(order->ExecutesBefore(add1, z));
181   EXPECT_FALSE(order->ExecutesBefore(add1, add1));
182   EXPECT_FALSE(order->ExecutesBefore(add2, x));
183   EXPECT_FALSE(order->ExecutesBefore(add2, y));
184   EXPECT_FALSE(order->ExecutesBefore(add2, z));
185   EXPECT_FALSE(order->ExecutesBefore(add2, add2));
186 }
187 
188 // Test of two streams.
TEST_F(HloScheduleTest,ConcurrentMatMul)189 TEST_F(HloScheduleTest, ConcurrentMatMul) {
190   HloComputation::Builder builder("entry_computation");
191   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
192       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
193   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
194       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
195   HloInstruction* dot1 = builder.AddInstruction(
196       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
197   HloInstruction* dot2 = builder.AddInstruction(
198       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
199   HloInstruction* add = builder.AddInstruction(
200       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
201 
202   auto module = CreateNewModule();
203   module->AddEntryComputation(builder.Build(add));
204 
205   std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
206   EXPECT_NE(streams->StreamNumberForHlo(*dot1),
207             streams->StreamNumberForHlo(*dot2));
208 
209   auto schedule = BuildHloSchedule(*module, *streams);
210   // Remove parameters, which are unordered.
211   HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y});
212   EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) ||
213               thunk_launch_order == HloVec({dot2, dot1, add}));
214 
215   // Parameters x,y are mutually unordered, while dot1, dot2 and add are
216   // transitively ordered by operands.
217   auto order = schedule->ConsumeHloOrdering();
218   EXPECT_TRUE(order->ExecutesBefore(x, dot1));
219   EXPECT_TRUE(order->ExecutesBefore(x, dot2));
220   EXPECT_TRUE(order->ExecutesBefore(y, dot1));
221   EXPECT_TRUE(order->ExecutesBefore(y, dot2));
222   EXPECT_TRUE(order->ExecutesBefore(dot1, add));
223   EXPECT_TRUE(order->ExecutesBefore(dot2, add));
224 
225   EXPECT_FALSE(order->ExecutesBefore(x, x));
226   EXPECT_FALSE(order->ExecutesBefore(x, y));
227   EXPECT_FALSE(order->ExecutesBefore(y, x));
228   EXPECT_FALSE(order->ExecutesBefore(y, y));
229   EXPECT_FALSE(order->ExecutesBefore(dot1, x));
230   EXPECT_FALSE(order->ExecutesBefore(dot1, y));
231   EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
232   EXPECT_FALSE(order->ExecutesBefore(dot1, dot2));
233   EXPECT_FALSE(order->ExecutesBefore(dot2, x));
234   EXPECT_FALSE(order->ExecutesBefore(dot2, y));
235   EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
236   EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
237   EXPECT_FALSE(order->ExecutesBefore(add, x));
238   EXPECT_FALSE(order->ExecutesBefore(add, y));
239   EXPECT_FALSE(order->ExecutesBefore(add, dot1));
240   EXPECT_FALSE(order->ExecutesBefore(add, dot2));
241   EXPECT_FALSE(order->ExecutesBefore(add, add));
242 }
243 
244 // Test of multiple streams.
TEST_F(HloScheduleTest,LatticeMatMul)245 TEST_F(HloScheduleTest, LatticeMatMul) {
246   //      d00      -- layer 0
247   //     /   \
248   //   d10   d11   -- layer 1
249   //  /   \ /   \
250   // d20  d21  d22 -- layer 2
251   //  \   / \   /
252   //   d30   d31   -- layer 3
253   //     \   /
254   //      d40      -- layer 4
255   HloComputation::Builder builder("entry_computation");
256   std::vector<HloInstruction*> params;
257   params.reserve(6);
258   for (int i = 0; i < 6; ++i) {
259     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
260         i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
261   }
262   HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
263       f32_2x2_, HloOpcode::kDot, params[2], params[3]));
264   HloInstruction* d10 = builder.AddInstruction(
265       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
266   HloInstruction* d11 = builder.AddInstruction(
267       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
268   HloInstruction* d20 = builder.AddInstruction(
269       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
270   HloInstruction* d21 = builder.AddInstruction(
271       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
272   HloInstruction* d22 = builder.AddInstruction(
273       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
274   HloInstruction* d30 = builder.AddInstruction(
275       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
276   HloInstruction* d31 = builder.AddInstruction(
277       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
278   HloInstruction* d40 = builder.AddInstruction(
279       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
280 
281   auto module = CreateNewModule();
282   module->AddEntryComputation(builder.Build(d40));
283 
284   std::unique_ptr<StreamAssignment> streams = AssignStreams(*module);
285   // The two dots on layer 1 are concurrent.
286   EXPECT_NE(streams->StreamNumberForHlo(*d10),
287             streams->StreamNumberForHlo(*d11));
288   // The three dots on layer 2 are concurrent.
289   EXPECT_NE(streams->StreamNumberForHlo(*d20),
290             streams->StreamNumberForHlo(*d21));
291   EXPECT_NE(streams->StreamNumberForHlo(*d20),
292             streams->StreamNumberForHlo(*d22));
293   EXPECT_NE(streams->StreamNumberForHlo(*d21),
294             streams->StreamNumberForHlo(*d22));
295   // The two dots on layer 3 are concurrent.
296   EXPECT_NE(streams->StreamNumberForHlo(*d30),
297             streams->StreamNumberForHlo(*d31));
298 
299   // We don't check the thunk launch order, since there are many valid total
300   // orders, and it's annoying to express.
301   auto schedule = BuildHloSchedule(*module, *streams);
302 
303   auto order = schedule->ConsumeHloOrdering();
304   const HloVec all_params(
305       {params[0], params[1], params[2], params[3], params[4], params[5]});
306   const HloVec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40});
307 
308   // Parameters are mutually unordered, and never execute before ops.
309   for (const HloInstruction* param : all_params) {
310     for (const HloInstruction* param2 : all_params) {
311       EXPECT_FALSE(order->ExecutesBefore(param, param2));
312     }
313     for (const HloInstruction* op : all_ops) {
314       EXPECT_FALSE(order->ExecutesBefore(op, param));
315     }
316   }
317 
318   // Check ordering of params before ops.
319   for (const HloInstruction* op : all_ops) {
320     if (op == d20 || op == d30 || op == d40) {
321       EXPECT_TRUE(order->ExecutesBefore(params[0], op));
322     } else {
323       EXPECT_FALSE(order->ExecutesBefore(params[0], op));
324     }
325     if (op != d00 && op != d11 && op != d22) {
326       EXPECT_TRUE(order->ExecutesBefore(params[1], op));
327     } else {
328       EXPECT_FALSE(order->ExecutesBefore(params[1], op));
329     }
330     EXPECT_TRUE(order->ExecutesBefore(params[2], op));
331     EXPECT_TRUE(order->ExecutesBefore(params[3], op));
332     if (op != d00 && op != d10 && op != d20) {
333       EXPECT_TRUE(order->ExecutesBefore(params[4], op));
334     } else {
335       EXPECT_FALSE(order->ExecutesBefore(params[4], op));
336     }
337     if (op == d22 || op == d31 || op == d40) {
338       EXPECT_TRUE(order->ExecutesBefore(params[5], op));
339     } else {
340       EXPECT_FALSE(order->ExecutesBefore(params[5], op));
341     }
342   }
343 
344   // Check ordering of ops before ops.
345   for (const HloInstruction* op : all_ops) {
346     if (op != d00) {
347       EXPECT_TRUE(order->ExecutesBefore(d00, op));
348     } else {
349       EXPECT_FALSE(order->ExecutesBefore(d00, op));
350     }
351 
352     if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) {
353       EXPECT_TRUE(order->ExecutesBefore(d10, op));
354     } else {
355       EXPECT_FALSE(order->ExecutesBefore(d10, op));
356     }
357 
358     if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) {
359       EXPECT_TRUE(order->ExecutesBefore(d11, op));
360     } else {
361       EXPECT_FALSE(order->ExecutesBefore(d11, op));
362     }
363 
364     if (op == d30 || op == d40) {
365       EXPECT_TRUE(order->ExecutesBefore(d20, op));
366     } else {
367       EXPECT_FALSE(order->ExecutesBefore(d20, op));
368     }
369 
370     if (op == d30 || op == d31 || op == d40) {
371       EXPECT_TRUE(order->ExecutesBefore(d21, op));
372     } else {
373       EXPECT_FALSE(order->ExecutesBefore(d21, op));
374     }
375 
376     if (op == d31 || op == d40) {
377       EXPECT_TRUE(order->ExecutesBefore(d22, op));
378     } else {
379       EXPECT_FALSE(order->ExecutesBefore(d22, op));
380     }
381 
382     if (op == d40) {
383       EXPECT_TRUE(order->ExecutesBefore(d30, op));
384       EXPECT_TRUE(order->ExecutesBefore(d31, op));
385     } else {
386       EXPECT_FALSE(order->ExecutesBefore(d30, op));
387       EXPECT_FALSE(order->ExecutesBefore(d31, op));
388     }
389 
390     EXPECT_FALSE(order->ExecutesBefore(d40, op));
391   }
392 }
393 
394 }  // namespace gpu
395 }  // namespace xla
396