1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/map_util.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24
25 namespace xla {
26 namespace gpu {
27
HasStreamAssigned(const HloInstruction & hlo) const28 bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
29 return hlo_to_stream_number_.contains(&hlo);
30 }
31
StreamNumberForHlo(const HloInstruction & hlo) const32 int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
33 return FindOrDie(hlo_to_stream_number_, &hlo);
34 }
35
AssignStreamToHlo(const HloInstruction * hlo,int stream_num)36 void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
37 int stream_num) {
38 CHECK_GE(stream_num, 0);
39 if (stream_num >= stream_count_) {
40 stream_count_ = stream_num + 1;
41 }
42 InsertOrDie(&hlo_to_stream_number_, hlo, stream_num);
43 VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString();
44 }
45
46 namespace {
47
48 // Returns whether the two HLOs can run concurrently, i.e., neither is a
49 // transitive consumer of the other.
CanRunConcurrently(const HloInstruction & a,const HloInstruction & b,const HloReachabilityMap & reachability)50 bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
51 const HloReachabilityMap& reachability) {
52 return !reachability.IsConnected(&a, &b);
53 }
54
// Sentinel stream number meaning "this instruction needs no stream".
constexpr int kInvalidStreamNum = -1;

// Returns true iff `stream_num` is not the kInvalidStreamNum sentinel. Only the
// sentinel itself is rejected; other negative values are not checked here.
inline bool IsStreamNumValid(int stream_num) {
  const bool is_sentinel = (stream_num == kInvalidStreamNum);
  return !is_sentinel;
}
60
61 // Returns which existing stream to assign to `hlo`, or -1 if a stream is not
62 // needed. `stream_assignment` is the existing stream assignment for all
63 // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that
64 // are topologically before `hlo`.
ComputeStreamToAssign(const HloInstruction & hlo,const StreamAssignment & stream_assignment,const HloReachabilityMap & reachability,const std::vector<const HloInstruction * > & seen_gemms)65 int ComputeStreamToAssign(
66 const HloInstruction& hlo, const StreamAssignment& stream_assignment,
67 const HloReachabilityMap& reachability,
68 const std::vector<const HloInstruction*>& seen_gemms) {
69 if (hlo.opcode() == HloOpcode::kParameter ||
70 hlo.opcode() == HloOpcode::kConstant) {
71 // kParameter and kConstant do not need a thunk.
72 return kInvalidStreamNum;
73 }
74
75 if (hlo.GetModule()
76 ->config()
77 .debug_options()
78 .xla_gpu_disable_multi_streaming()) {
79 return 0;
80 }
81
82 if (!ImplementedAsGemm(hlo)) {
83 // If `hlo` is not implemented as a GEMM, keep it close to its operands to
84 // avoid excessive synchronization.
85 int stream_num = -1;
86 for (const auto* operand : hlo.operands()) {
87 if (stream_assignment.HasStreamAssigned(*operand)) {
88 stream_num = std::max(stream_num,
89 stream_assignment.StreamNumberForHlo(*operand));
90 }
91 }
92 if (!IsStreamNumValid(stream_num)) {
93 stream_num = 0;
94 }
95 return stream_num;
96 }
97
98 // Assign different streams to concurrent GEMMs. The code below uses a
99 // greedy approach. First, we compute as forbidden_stream_numbers the
100 // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign
101 // `hlo` a different stream.
102 absl::flat_hash_set<int> forbidden_stream_numbers;
103 for (const auto* seen_gemm : seen_gemms) {
104 int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm);
105 if (!forbidden_stream_numbers.contains(stream_num) &&
106 CanRunConcurrently(*seen_gemm, hlo, reachability)) {
107 forbidden_stream_numbers.insert(stream_num);
108 }
109 }
110
111 for (int stream_num = 0; stream_num < stream_assignment.StreamCount();
112 ++stream_num) {
113 if (!forbidden_stream_numbers.contains(stream_num)) {
114 return stream_num;
115 }
116 }
117 return stream_assignment.StreamCount();
118 }
119
120 } // namespace
121
AssignStreams(const HloModule & module)122 std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
123 auto stream_assignment = absl::make_unique<StreamAssignment>();
124 const HloComputation& computation = *module.entry_computation();
125 std::unique_ptr<HloReachabilityMap> reachability =
126 HloReachabilityMap::Build(&computation);
127 std::vector<const HloInstruction*> seen_gemms;
128 // The execution of different RNG Hlo instructions in the same module updates
129 // a common global variable. To avoid a race condition, we simply assign all
130 // RNG kernels to the same stream to make them run sequentially.
131 //
132 // TODO(b/111791052): If we remove such a common variable, we will need to
133 // clean up the code here.
134 int stream_num_for_rng = kInvalidStreamNum;
135 for (const auto* hlo : computation.MakeInstructionPostOrder()) {
136 // If we ever enable fusion of RNG instructions, we will need to extend this
137 // code to look inside a fused instruction.
138 int stream_num = (hlo->opcode() == HloOpcode::kRng &&
139 IsStreamNumValid(stream_num_for_rng))
140 ? stream_num_for_rng
141 : ComputeStreamToAssign(*hlo, *stream_assignment,
142 *reachability, seen_gemms);
143 if (IsStreamNumValid(stream_num)) {
144 stream_assignment->AssignStreamToHlo(hlo, stream_num);
145 if (hlo->opcode() == HloOpcode::kRng &&
146 !IsStreamNumValid(stream_num_for_rng)) {
147 stream_num_for_rng = stream_num;
148 }
149 }
150 if (ImplementedAsGemm(*hlo)) {
151 seen_gemms.push_back(hlo);
152 }
153 }
154 return stream_assignment;
155 }
156
157 } // namespace gpu
158 } // namespace xla
159