1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/xla/comparison_util.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_value.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/types.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/platform/logging.h"
40
41 namespace xla {
42
43 using absl::StrAppend;
44
45 namespace {
46
47 using FlatValueSet = absl::flat_hash_set<const HloValue*>;
48
ComputeInputOutputAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)49 void ComputeInputOutputAliasedValues(const HloValue& value,
50 const HloDataflowAnalysis& dataflow,
51 FlatValueSet& aliased_values) {
52 const HloModule& module = dataflow.module();
53 const HloComputation& entry_computation = *module.entry_computation();
54 const HloInputOutputAliasConfig& io_alias_config =
55 module.input_output_alias_config();
56
57 // If the value shows up in a root instruction, alias it with parameter
58 // instruction.
59 for (const HloPosition& pos : value.positions()) {
60 if (pos.instruction == entry_computation.root_instruction()) {
61 std::optional<HloInputOutputAliasConfig::Alias> aliased_input =
62 io_alias_config.GetAliasedParameter(pos.index);
63 if (aliased_input) {
64 aliased_values.insert(
65 &dataflow.GetUniqueValueAt(entry_computation.parameter_instruction(
66 aliased_input->parameter_number),
67 aliased_input->parameter_index));
68 }
69 }
70 }
71 }
72
ComputeWhileAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)73 void ComputeWhileAliasedValues(const HloValue& value,
74 const HloDataflowAnalysis& dataflow,
75 FlatValueSet& aliased_values) {
76 VLOG(3) << "Compute kWhile aliases";
77 // Value is init of a while (use is while).
78 for (const HloUse& use : value.GetUses()) {
79 if (use.instruction->opcode() == HloOpcode::kWhile) {
80 // Determine the while value that this shares a buffer with.
81 const HloValue& while_value =
82 dataflow.GetUniqueValueAt(use.instruction, use.operand_index);
83 aliased_values.insert(&while_value);
84 VLOG(3) << " value is init value to a while; must share buffer with "
85 "while value "
86 << while_value;
87 }
88 }
89 // Value is a parameter of a while body/condition.
90 if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
91 const HloComputation* computation = value.defining_instruction()->parent();
92 const CallGraphNode& call_graph_node =
93 dataflow.call_graph().GetNode(computation);
94 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
95 if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
96 // Call graph must have been flattened.
97 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
98
99 const HloValue& while_value = dataflow.GetUniqueValueAt(
100 callsite.instruction(), value.defining_index());
101 VLOG(3) << " value is parameter value of the body or condition of a "
102 "while; must share buffer with while value "
103 << while_value;
104 aliased_values.insert(&while_value);
105 }
106 }
107 }
108 // Value is the root of a while body.
109 for (const HloPosition& position : value.positions()) {
110 if (!position.instruction->IsRoot()) continue;
111
112 const HloComputation* computation = position.instruction->parent();
113 const CallGraphNode& call_graph_node =
114 dataflow.call_graph().GetNode(computation);
115
116 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
117 if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
118 callsite.instruction()->while_body() == computation) {
119 // Call graph must have been flattened.
120 CHECK_EQ(call_graph_node.caller_callsites().size(), 1)
121 << "Call graph must have been flattened.";
122
123 const HloValue& while_value =
124 dataflow.GetUniqueValueAt(callsite.instruction(), position.index);
125 VLOG(3) << " value @ " << position << " is root of "
126 << callsite.instruction()->name()
127 << "; body root and while value root must share buffer "
128 "among them: "
129 << while_value;
130 aliased_values.insert(&while_value);
131 }
132 }
133 }
134 }
135
ComputeConditionalAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)136 void ComputeConditionalAliasedValues(const HloValue& value,
137 const HloDataflowAnalysis& dataflow,
138 FlatValueSet& aliased_values) {
139 VLOG(3) << "Compute kConditional aliases";
140 // Aliases the buffers of the true/false computations roots, with the one of
141 // the conditional.
142 for (const HloPosition& position : value.positions()) {
143 if (!position.instruction->IsRoot()) continue;
144
145 const HloComputation* computation = position.instruction->parent();
146 const CallGraphNode& call_graph_node =
147 dataflow.call_graph().GetNode(computation);
148 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
149 if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
150 // Call graph must have been flattened.
151 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
152
153 const HloValue& cond_value =
154 dataflow.GetUniqueValueAt(callsite.instruction(), position.index);
155 VLOG(3) << " value @ " << position << " is root of "
156 << callsite.instruction()->name()
157 << "; branch computation roots must share buffer among them : "
158 << cond_value;
159 aliased_values.insert(&cond_value);
160 }
161 }
162 }
163 }
164
ComputeInPlaceOperationAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)165 void ComputeInPlaceOperationAliasedValues(const HloValue& value,
166 const HloDataflowAnalysis& dataflow,
167 FlatValueSet& aliased_values) {
168 VLOG(3) << "Compute aliases for in-place operations (e.g. "
169 "kDynamicUpdateSlice and kScatter)";
170 for (const HloPosition& position : value.positions()) {
171 HloInstruction* instruction = position.instruction;
172 for (const auto& operand_and_output_index :
173 HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
174 if (position.index == operand_and_output_index.second) {
175 const HloOperandIndex& operand_index = operand_and_output_index.first;
176 const HloValue& operand_value = dataflow.GetUniqueValueAt(
177 instruction->operand(operand_index.operand_number),
178 operand_index.operand_index);
179 VLOG(3) << " operand value " << operand_value << " aliases.";
180 aliased_values.insert(&operand_value);
181 }
182 }
183 }
184
185 for (const HloUse& use : value.GetUses()) {
186 for (const auto& operand_and_output_index :
187 HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
188 const HloOperandIndex& operand_index = operand_and_output_index.first;
189 if (use.operand_number == operand_index.operand_number &&
190 use.operand_index == operand_index.operand_index) {
191 const HloValue& use_value = dataflow.GetUniqueValueAt(
192 use.instruction, operand_and_output_index.second);
193 VLOG(3) << " use value " << use_value << " aliases.";
194 aliased_values.insert(&use_value);
195 }
196 }
197 }
198 }
199
200 // Compute and return a set of values that the given value must be aliased
201 // with due to HLO aliasing rules (including the value itself).
ComputeAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow)202 FlatValueSet ComputeAliasedValues(const HloValue& value,
203 const HloDataflowAnalysis& dataflow) {
204 if (VLOG_IS_ON(2)) {
205 for (const HloUse& use : value.GetUses()) {
206 VLOG(2) << "Use of value " << value << ": " << use;
207 }
208 }
209
210 FlatValueSet aliased_values{&value};
211 ComputeInputOutputAliasedValues(value, dataflow, aliased_values);
212 ComputeWhileAliasedValues(value, dataflow, aliased_values);
213 ComputeConditionalAliasedValues(value, dataflow, aliased_values);
214 ComputeInPlaceOperationAliasedValues(value, dataflow, aliased_values);
215 return aliased_values;
216 }
217
CreateBuffers(const HloDataflowAnalysis & dataflow)218 std::vector<HloBuffer> CreateBuffers(const HloDataflowAnalysis& dataflow) {
219 const std::vector<HloValue*>& values = dataflow.values();
220 size_t num_buffers = values.size();
221 // The sets of values contained in each buffer.
222 std::vector<FlatValueSet> buffer_values(values.size());
223 // Maps values to the set of values with which they are aliased.
224 absl::flat_hash_map<const HloValue*, FlatValueSet*> value_to_set;
225 value_to_set.reserve(values.size());
226
227 for (size_t i = 0; i < values.size(); ++i) {
228 buffer_values[i].insert(values[i]);
229 value_to_set[values[i]] = &buffer_values[i];
230 }
231
232 // Merge together sets of HloValues which must be in the same HloBuffer
233 // because of aliasing rules (e.g. in-place kWhile instruction).
234 for (const HloValue* value : values) {
235 VLOG(3) << "Merging colocated values, value: " << *value;
236
237 FlatValueSet aliased_values = ComputeAliasedValues(*value, dataflow);
238 if (aliased_values.size() < 2) continue; // Fast path.
239
240 // The sets of values that are transitively aliased together.
241 std::vector<std::pair<FlatValueSet*, HloValue::Id>> aliased_sets;
242 aliased_sets.reserve(aliased_values.size());
243 for (const HloValue* aliased : aliased_values) {
244 aliased_sets.push_back({value_to_set[aliased], aliased->id()});
245 }
246
247 // Use the largest set to collect the union of the aliased sets (as it is
248 // more efficient to merge smaller sets into larger). Break ties using
249 // value ID to maintain determinism.
250 auto key = [](const auto& set_and_id) {
251 return std::make_pair(set_and_id.first->size(), -set_and_id.second);
252 };
253 FlatValueSet* union_set =
254 absl::c_max_element(aliased_sets, LessThanByKey(key))->first;
255
256 for (auto& aliased_set_and_id : aliased_sets) {
257 FlatValueSet* aliased_set = aliased_set_and_id.first;
258 if ((aliased_set != union_set) && !aliased_set->empty()) {
259 for (const HloValue* aliased_value : *aliased_set) {
260 CHECK(union_set->insert(aliased_value).second);
261 value_to_set[aliased_value] = union_set;
262 }
263 aliased_set->clear();
264 --num_buffers;
265 }
266 }
267 }
268
269 // Create a vector of HloBuffers, one for each non-empty set of values.
270 std::vector<HloBuffer> buffers;
271 buffers.reserve(num_buffers);
272
273 for (const FlatValueSet& value_set : buffer_values) {
274 if (!value_set.empty()) {
275 HloBuffer::Id id = buffers.size();
276 buffers.push_back({id, HloValueSet(value_set).TakeValues()});
277 }
278 }
279
280 CHECK_EQ(buffers.size(), num_buffers);
281 return buffers;
282 }
283
284 } // namespace
285
HloAliasAnalysis(const HloModule * module)286 HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
287
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const288 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
289 const HloInstruction* instruction, const ShapeIndex& index) const {
290 std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
291 CHECK_EQ(buffers.size(), 1);
292 return *buffers[0];
293 }
294
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)295 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
296 const HloInstruction* instruction, const ShapeIndex& index) {
297 return GetBuffer(const_cast<const HloAliasAnalysis*>(this)
298 ->GetUniqueBufferAt(instruction, index)
299 .id());
300 }
301
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const302 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
303 const HloInstruction* instruction, const ShapeIndex& index) const {
304 const HloValueSet& value_set =
305 dataflow_analysis_->GetValueSet(instruction, index);
306 std::vector<const HloBuffer*> buffers;
307 buffers.reserve(value_set.values().size());
308 for (const HloValue* value : value_set.values()) {
309 buffers.push_back(&GetBufferContainingValue(*value));
310 }
311
312 // Sort and uniquify vector before returning.
313 absl::c_sort(buffers, HloBuffer::IdLessThan);
314 buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
315
316 return buffers;
317 }
318
Verify() const319 Status HloAliasAnalysis::Verify() const {
320 // Verify consistency between the value_to_buffer_ map and
321 // HloBuffer::values().
322 for (const auto& pair : value_to_buffer_) {
323 const HloValue* value = pair.first;
324 const HloBuffer& buffer = *pair.second;
325 TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
326 }
327
328 for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
329 const HloBuffer& buffer = buffers_[id];
330 TF_RET_CHECK(buffer.id() == id);
331
332 HloValue::Id last_value_id = -1;
333 for (const HloValue* value : buffer.values()) {
334 TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
335
336 // Also verify the values in HloBuffer are unique and sorted by id.
337 TF_RET_CHECK(value->id() > last_value_id);
338 last_value_id = value->id();
339 }
340 }
341
342 return OkStatus();
343 }
344
ToString() const345 std::string HloAliasAnalysis::ToString() const {
346 std::string out =
347 absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
348 StrAppend(&out, " Buffers at each position:\n");
349 for (const HloComputation* computation : module_->computations()) {
350 for (const HloInstruction* instruction : computation->instructions()) {
351 StrAppend(&out, " ", instruction->name(), ":\n");
352 if (instruction->shape().IsTuple()) {
353 ShapeUtil::ForEachSubshape(
354 instruction->shape(),
355 [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
356 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
357 for (const HloBuffer* buffer :
358 ComputeBuffersAt(instruction, index)) {
359 StrAppend(&out, " ", buffer->ToString(), "\n");
360 }
361 });
362 } else {
363 for (const HloBuffer* buffer :
364 ComputeBuffersAt(instruction, /*index=*/{})) {
365 StrAppend(&out, " ", buffer->ToString(), "\n");
366 }
367 }
368 }
369 }
370
371 StrAppend(&out, " Buffers:\n");
372 for (const HloBuffer& buffer : buffers()) {
373 StrAppend(&out, " ", buffer.ToString(), "\n");
374 StrAppend(&out, " positions:\n");
375 for (const HloPosition& position : buffer.ComputePositions()) {
376 StrAppend(&out, " ", position.ToString(), "\n");
377 }
378 }
379
380 return out;
381 }
382
383 /* static */
Run(const HloModule * module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer)384 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
385 const HloModule* module,
386 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
387 VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
388 XLA_VLOG_LINES(2, module->ToString());
389
390 auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
391 TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
392 HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
393 /*bitcast_defines_value=*/false,
394 can_share_buffer));
395
396 size_t num_values = alias_analysis->dataflow_analysis_->values().size();
397 alias_analysis->buffers_ = CreateBuffers(alias_analysis->dataflow_analysis());
398 alias_analysis->value_to_buffer_.reserve(num_values);
399
400 for (HloBuffer& buffer : alias_analysis->buffers_) {
401 for (const HloValue* value : buffer.values()) {
402 alias_analysis->value_to_buffer_[value] = &buffer;
403 }
404 }
405
406 CHECK_EQ(alias_analysis->value_to_buffer_.size(), num_values);
407 TF_DCHECK_OK(alias_analysis->Verify());
408
409 HloInstruction* root = module->entry_computation()->root_instruction();
410 ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& /*subshape*/,
411 const ShapeIndex& index) {
412 std::vector<const HloBuffer*> buffers =
413 alias_analysis->ComputeBuffersAt(root, index);
414 alias_analysis->live_out_buffers_.insert(buffers.begin(), buffers.end());
415 });
416
417 XLA_VLOG_LINES(2, alias_analysis->ToString());
418 return std::move(alias_analysis);
419 }
420
421 } // namespace xla
422