• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
17 
18 #include "absl/strings/str_format.h"
19 
20 namespace xla {
21 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)22 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
23     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
24     const HloComputation* computation, bool module_scoped_analysis) {
25   std::unique_ptr<HloLiveRange> hlo_live_range(
26       new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
27   hlo_live_range->schedule_end_time_ =
28       hlo_live_range->FlattenSchedule(*computation, 0);
29   hlo_live_range->CalculateBufferStartEndMap();
30   hlo_live_range->NormalizeAliasedBuffers();
31   return std::move(hlo_live_range);
32 }
33 
NormalizeAliasedBuffers()34 void HloLiveRange::NormalizeAliasedBuffers() {
35   for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) {
36     std::vector<const HloValue*> aliased_buffers;
37     for (const HloValue* hlo_value : hlo_buffer.values()) {
38       if (buffer_live_ranges_.contains(hlo_value)) {
39         aliased_buffers.push_back(hlo_value);
40       }
41     }
42     absl::c_sort(
43         aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
44           const TimeBound& live_range1 = buffer_live_ranges_.at(value1);
45           const TimeBound& live_range2 = buffer_live_ranges_.at(value2);
46 
47           return std::forward_as_tuple(live_range1.start, live_range1.end) <
48                  std::forward_as_tuple(live_range2.start, live_range2.end);
49         });
50 
51     for (int64 i = 0; i + 1 < aliased_buffers.size(); ++i) {
52       const HloValue* value1 = aliased_buffers[i];
53       const HloValue* value2 = aliased_buffers[i + 1];
54       TimeBound& live_range1 = buffer_live_ranges_[value1];
55       TimeBound& live_range2 = buffer_live_ranges_[value2];
56       if (live_range1.start == live_range2.start) {
57         // If value1 has the same start time as value2, make value1 disappear
58         // by setting the end time same as start time:
59         //
60         // Before:
61         // +----+           value1
62         // +----------+     value2
63         //
64         // After:
65         // +                value1
66         // +----------+     value2
67         //
68         // Note that only when heap simulator runs before copy insertion can
69         // this happen where one instruction defines multiple aliased buffers
70         // -- This is illegle to execute and can be fixed by copy insertion
71         // later.
72         live_range1.end = live_range2.end;
73         continue;
74       }
75 
76       if (live_range1.end < live_range2.start) {
77         continue;
78       }
79 
80       if (live_range1.end > live_range2.end) {
81         live_range2.end = live_range1.end;
82       }
83       live_range1.end = live_range2.start - 1;
84     }
85   }
86 }
87 
88 // FlattenSchedule walks through the computation and tracks down the ordinal
89 // number of each instruction in the schedule.
FlattenSchedule(const HloComputation & computation,int64 start_time)90 int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
91                                     int64 start_time) {
92   if (!schedule_.is_computation_scheduled(&computation)) {
93     total_order_scheduled_ = false;
94     return start_time;
95   }
96 
97   const HloInstructionSequence& instruction_sequence =
98       schedule_.sequence(&computation);
99   int64 time = start_time;
100   for (HloInstruction* instruction : instruction_sequence.instructions()) {
101     if (module_scoped_analysis_) {
102       // Recurse into sub computations if running with module scoped analysis
103       // mode.
104       if (instruction->opcode() == HloOpcode::kCall ||
105           instruction->opcode() == HloOpcode::kConditional) {
106         for (const HloComputation* called_computation :
107              instruction->called_computations()) {
108           time = FlattenSchedule(*called_computation, time);
109         }
110       }
111       if (instruction->opcode() == HloOpcode::kWhile) {
112         time = FlattenSchedule(*instruction->while_condition(), time);
113         time = FlattenSchedule(*instruction->while_body(), time);
114       }
115     }
116     if (instruction_schedule_.count(instruction) != 0) {
117       continue;
118     }
119     instruction_schedule_.insert({instruction, time++});
120     flattened_instruction_sequence_.push_back(instruction);
121   }
122   computation_span_times_.try_emplace(&computation,
123                                       TimeBound{start_time, time});
124   DCHECK_EQ(instruction_schedule_.size(),
125             flattened_instruction_sequence_.size());
126   DCHECK_LE(instruction_schedule_.size(), time);
127   return time;
128 }
129 
CalculateBufferStartEndMap()130 void HloLiveRange::CalculateBufferStartEndMap() {
131   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
132     // Ignore buffers that are not defined.
133     if (instruction_schedule_.count(value->defining_instruction()) == 0) {
134       continue;
135     }
136 
137     int64 buffer_start_time = instruction_schedule_[value->instruction()];
138 
139     int64 buffer_end_time = -1;
140     for (const HloUse& use : value->uses()) {
141       const HloInstruction* used = use.instruction;
142       // As an optimization, we deem a while's init value's live range ends as
143       // soon as the loop body starts. This optimization is only applicable in
144       // module scoped mode.
145       if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
146         // The current live range is at the end of the while, move it to the
147         // beginning of the body.
148         used = used->while_body()->parameter_instruction(0);
149         VLOG(1) << "Moved value " << value->ToShortString()
150                 << " to while param: " << used->ToString();
151       }
152       if (instruction_schedule_.count(used) == 0) {
153         // We didn't track the instruction `used`. This happens when we do
154         // computation scope (versus module scope) heap simulation and when
155         // the used instruction is outside of the computation being simulated.
156         continue;
157       }
158       buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
159     }
160 
161     // Parameters are defined at the beginning of the computation. This prevents
162     // any instruction that's scheduled before the parameter clobbers the
163     // parameter's buffer.
164     if (value->instruction()->opcode() == HloOpcode::kParameter) {
165       const HloComputation* computation = value->instruction()->parent();
166       auto it = computation_span_times_.find(computation);
167       if (it != computation_span_times_.end()) {
168         buffer_start_time = std::min(buffer_start_time, it->second.start);
169       }
170     }
171 
172     if (buffer_end_time == -1) {
173       buffer_end_time = buffer_start_time;
174     }
175 
176     for (const HloPosition& position : value->positions()) {
177       const HloComputation* position_comp = position.instruction->parent();
178       // If this instruction lives out, the live range of the instruction
179       // should be extended to the end of the computation.
180       if (position.instruction == position_comp->root_instruction()) {
181         auto it = computation_span_times_.find(position_comp);
182         if (it == computation_span_times_.end()) {
183           continue;
184         }
185         buffer_end_time = std::max(buffer_end_time, it->second.end);
186       }
187     }
188 
189     const HloModule* module = value->instruction()->parent()->parent();
190 
191     // Readonly entry parameters (parameters that don't alias) live across whole
192     // computation.
193     if (value->instruction()->opcode() == HloOpcode::kParameter &&
194         value->instruction()->parent() == module->entry_computation() &&
195         !module->input_output_alias_config().ParameterHasAlias(
196             value->instruction()->parameter_number(), value->index())) {
197       buffer_end_time = schedule_end_time_;
198     }
199 
200     CHECK(buffer_start_time <= buffer_end_time)
201         << buffer_start_time << ", " << buffer_end_time
202         << value->instruction()->ToString();
203 
204     auto& live_range = buffer_live_ranges_[value];
205     live_range.start = buffer_start_time;
206     live_range.end = buffer_end_time;
207   }
208 }
209 
ToString() const210 std::string HloLiveRange::ToString() const {
211   std::string output;
212   absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
213                         schedule_end_time_);
214   absl::StrAppendFormat(&output, "  InstructionSequence:\n");
215   auto& instructions = flattened_instruction_sequence().instructions();
216   for (int64 i = 0; i < instructions.size(); ++i) {
217     absl::StrAppendFormat(&output, "    %d:%s\n", i, instructions[i]->name());
218   }
219 
220   absl::StrAppendFormat(&output, "  BufferLiveRange:\n");
221 
222   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
223     auto it = buffer_live_ranges_.find(value);
224     if (it != buffer_live_ranges_.end()) {
225       absl::StrAppendFormat(
226           &output, "    %s%s:%d-%d\n", value->instruction()->name(),
227           value->index().ToString(), it->second.start, it->second.end);
228     }
229   }
230 
231   return output;
232 }
233 
234 }  // namespace xla
235