1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
17
18 #include "absl/strings/str_format.h"
19
20 namespace xla {
21 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)22 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
23 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
24 const HloComputation* computation, bool module_scoped_analysis) {
25 std::unique_ptr<HloLiveRange> hlo_live_range(
26 new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
27 hlo_live_range->schedule_end_time_ =
28 hlo_live_range->FlattenSchedule(*computation, 0);
29 hlo_live_range->CalculateBufferStartEndMap();
30 hlo_live_range->NormalizeAliasedBuffers();
31 return std::move(hlo_live_range);
32 }
33
NormalizeAliasedBuffers()34 void HloLiveRange::NormalizeAliasedBuffers() {
35 for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) {
36 std::vector<const HloValue*> aliased_buffers;
37 for (const HloValue* hlo_value : hlo_buffer.values()) {
38 if (buffer_live_ranges_.contains(hlo_value)) {
39 aliased_buffers.push_back(hlo_value);
40 }
41 }
42 absl::c_sort(
43 aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
44 const TimeBound& live_range1 = buffer_live_ranges_.at(value1);
45 const TimeBound& live_range2 = buffer_live_ranges_.at(value2);
46
47 return std::forward_as_tuple(live_range1.start, live_range1.end) <
48 std::forward_as_tuple(live_range2.start, live_range2.end);
49 });
50
51 for (int64 i = 0; i + 1 < aliased_buffers.size(); ++i) {
52 const HloValue* value1 = aliased_buffers[i];
53 const HloValue* value2 = aliased_buffers[i + 1];
54 TimeBound& live_range1 = buffer_live_ranges_[value1];
55 TimeBound& live_range2 = buffer_live_ranges_[value2];
56 if (live_range1.start == live_range2.start) {
57 // If value1 has the same start time as value2, make value1 disappear
58 // by setting the end time same as start time:
59 //
60 // Before:
61 // +----+ value1
62 // +----------+ value2
63 //
64 // After:
65 // + value1
66 // +----------+ value2
67 //
68 // Note that only when heap simulator runs before copy insertion can
69 // this happen where one instruction defines multiple aliased buffers
70 // -- This is illegle to execute and can be fixed by copy insertion
71 // later.
72 live_range1.end = live_range2.end;
73 continue;
74 }
75
76 if (live_range1.end < live_range2.start) {
77 continue;
78 }
79
80 if (live_range1.end > live_range2.end) {
81 live_range2.end = live_range1.end;
82 }
83 live_range1.end = live_range2.start - 1;
84 }
85 }
86 }
87
88 // FlattenSchedule walks through the computation and tracks down the ordinal
89 // number of each instruction in the schedule.
FlattenSchedule(const HloComputation & computation,int64 start_time)90 int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
91 int64 start_time) {
92 if (!schedule_.is_computation_scheduled(&computation)) {
93 total_order_scheduled_ = false;
94 return start_time;
95 }
96
97 const HloInstructionSequence& instruction_sequence =
98 schedule_.sequence(&computation);
99 int64 time = start_time;
100 for (HloInstruction* instruction : instruction_sequence.instructions()) {
101 if (module_scoped_analysis_) {
102 // Recurse into sub computations if running with module scoped analysis
103 // mode.
104 if (instruction->opcode() == HloOpcode::kCall ||
105 instruction->opcode() == HloOpcode::kConditional) {
106 for (const HloComputation* called_computation :
107 instruction->called_computations()) {
108 time = FlattenSchedule(*called_computation, time);
109 }
110 }
111 if (instruction->opcode() == HloOpcode::kWhile) {
112 time = FlattenSchedule(*instruction->while_condition(), time);
113 time = FlattenSchedule(*instruction->while_body(), time);
114 }
115 }
116 if (instruction_schedule_.count(instruction) != 0) {
117 continue;
118 }
119 instruction_schedule_.insert({instruction, time++});
120 flattened_instruction_sequence_.push_back(instruction);
121 }
122 computation_span_times_.try_emplace(&computation,
123 TimeBound{start_time, time});
124 DCHECK_EQ(instruction_schedule_.size(),
125 flattened_instruction_sequence_.size());
126 DCHECK_LE(instruction_schedule_.size(), time);
127 return time;
128 }
129
CalculateBufferStartEndMap()130 void HloLiveRange::CalculateBufferStartEndMap() {
131 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
132 // Ignore buffers that are not defined.
133 if (instruction_schedule_.count(value->defining_instruction()) == 0) {
134 continue;
135 }
136
137 int64 buffer_start_time = instruction_schedule_[value->instruction()];
138
139 int64 buffer_end_time = -1;
140 for (const HloUse& use : value->uses()) {
141 const HloInstruction* used = use.instruction;
142 // As an optimization, we deem a while's init value's live range ends as
143 // soon as the loop body starts. This optimization is only applicable in
144 // module scoped mode.
145 if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
146 // The current live range is at the end of the while, move it to the
147 // beginning of the body.
148 used = used->while_body()->parameter_instruction(0);
149 VLOG(1) << "Moved value " << value->ToShortString()
150 << " to while param: " << used->ToString();
151 }
152 if (instruction_schedule_.count(used) == 0) {
153 // We didn't track the instruction `used`. This happens when we do
154 // computation scope (versus module scope) heap simulation and when
155 // the used instruction is outside of the computation being simulated.
156 continue;
157 }
158 buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
159 }
160
161 // Parameters are defined at the beginning of the computation. This prevents
162 // any instruction that's scheduled before the parameter clobbers the
163 // parameter's buffer.
164 if (value->instruction()->opcode() == HloOpcode::kParameter) {
165 const HloComputation* computation = value->instruction()->parent();
166 auto it = computation_span_times_.find(computation);
167 if (it != computation_span_times_.end()) {
168 buffer_start_time = std::min(buffer_start_time, it->second.start);
169 }
170 }
171
172 if (buffer_end_time == -1) {
173 buffer_end_time = buffer_start_time;
174 }
175
176 for (const HloPosition& position : value->positions()) {
177 const HloComputation* position_comp = position.instruction->parent();
178 // If this instruction lives out, the live range of the instruction
179 // should be extended to the end of the computation.
180 if (position.instruction == position_comp->root_instruction()) {
181 auto it = computation_span_times_.find(position_comp);
182 if (it == computation_span_times_.end()) {
183 continue;
184 }
185 buffer_end_time = std::max(buffer_end_time, it->second.end);
186 }
187 }
188
189 const HloModule* module = value->instruction()->parent()->parent();
190
191 // Readonly entry parameters (parameters that don't alias) live across whole
192 // computation.
193 if (value->instruction()->opcode() == HloOpcode::kParameter &&
194 value->instruction()->parent() == module->entry_computation() &&
195 !module->input_output_alias_config().ParameterHasAlias(
196 value->instruction()->parameter_number(), value->index())) {
197 buffer_end_time = schedule_end_time_;
198 }
199
200 CHECK(buffer_start_time <= buffer_end_time)
201 << buffer_start_time << ", " << buffer_end_time
202 << value->instruction()->ToString();
203
204 auto& live_range = buffer_live_ranges_[value];
205 live_range.start = buffer_start_time;
206 live_range.end = buffer_end_time;
207 }
208 }
209
ToString() const210 std::string HloLiveRange::ToString() const {
211 std::string output;
212 absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
213 schedule_end_time_);
214 absl::StrAppendFormat(&output, " InstructionSequence:\n");
215 auto& instructions = flattened_instruction_sequence().instructions();
216 for (int64 i = 0; i < instructions.size(); ++i) {
217 absl::StrAppendFormat(&output, " %d:%s\n", i, instructions[i]->name());
218 }
219
220 absl::StrAppendFormat(&output, " BufferLiveRange:\n");
221
222 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
223 auto it = buffer_live_ranges_.find(value);
224 if (it != buffer_live_ranges_.end()) {
225 absl::StrAppendFormat(
226 &output, " %s%s:%d-%d\n", value->instruction()->name(),
227 value->index().ToString(), it->second.start, it->second.end);
228 }
229 }
230
231 return output;
232 }
233
234 } // namespace xla
235