• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/xla/comparison_util.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_value.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/types.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/platform/logging.h"
40 
41 namespace xla {
42 
43 using absl::StrAppend;
44 
45 namespace {
46 
47 using FlatValueSet = absl::flat_hash_set<const HloValue*>;
48 
ComputeInputOutputAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)49 void ComputeInputOutputAliasedValues(const HloValue& value,
50                                      const HloDataflowAnalysis& dataflow,
51                                      FlatValueSet& aliased_values) {
52   const HloModule& module = dataflow.module();
53   const HloComputation& entry_computation = *module.entry_computation();
54   const HloInputOutputAliasConfig& io_alias_config =
55       module.input_output_alias_config();
56 
57   // If the value shows up in a root instruction, alias it with parameter
58   // instruction.
59   for (const HloPosition& pos : value.positions()) {
60     if (pos.instruction == entry_computation.root_instruction()) {
61       std::optional<HloInputOutputAliasConfig::Alias> aliased_input =
62           io_alias_config.GetAliasedParameter(pos.index);
63       if (aliased_input) {
64         aliased_values.insert(
65             &dataflow.GetUniqueValueAt(entry_computation.parameter_instruction(
66                                            aliased_input->parameter_number),
67                                        aliased_input->parameter_index));
68       }
69     }
70   }
71 }
72 
ComputeWhileAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)73 void ComputeWhileAliasedValues(const HloValue& value,
74                                const HloDataflowAnalysis& dataflow,
75                                FlatValueSet& aliased_values) {
76   VLOG(3) << "Compute kWhile aliases";
77   // Value is init of a while (use is while).
78   for (const HloUse& use : value.GetUses()) {
79     if (use.instruction->opcode() == HloOpcode::kWhile) {
80       // Determine the while value that this shares a buffer with.
81       const HloValue& while_value =
82           dataflow.GetUniqueValueAt(use.instruction, use.operand_index);
83       aliased_values.insert(&while_value);
84       VLOG(3) << "  value is init value to a while; must share buffer with "
85                  "while value "
86               << while_value;
87     }
88   }
89   // Value is a parameter of a while body/condition.
90   if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
91     const HloComputation* computation = value.defining_instruction()->parent();
92     const CallGraphNode& call_graph_node =
93         dataflow.call_graph().GetNode(computation);
94     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
95       if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
96         // Call graph must have been flattened.
97         CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
98 
99         const HloValue& while_value = dataflow.GetUniqueValueAt(
100             callsite.instruction(), value.defining_index());
101         VLOG(3) << "  value is parameter value of the body or condition of a "
102                    "while; must share buffer with while value "
103                 << while_value;
104         aliased_values.insert(&while_value);
105       }
106     }
107   }
108   // Value is the root of a while body.
109   for (const HloPosition& position : value.positions()) {
110     if (!position.instruction->IsRoot()) continue;
111 
112     const HloComputation* computation = position.instruction->parent();
113     const CallGraphNode& call_graph_node =
114         dataflow.call_graph().GetNode(computation);
115 
116     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
117       if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
118           callsite.instruction()->while_body() == computation) {
119         // Call graph must have been flattened.
120         CHECK_EQ(call_graph_node.caller_callsites().size(), 1)
121             << "Call graph must have been flattened.";
122 
123         const HloValue& while_value =
124             dataflow.GetUniqueValueAt(callsite.instruction(), position.index);
125         VLOG(3) << "  value @ " << position << " is root of "
126                 << callsite.instruction()->name()
127                 << "; body root and while value root must share buffer "
128                    "among them: "
129                 << while_value;
130         aliased_values.insert(&while_value);
131       }
132     }
133   }
134 }
135 
ComputeConditionalAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)136 void ComputeConditionalAliasedValues(const HloValue& value,
137                                      const HloDataflowAnalysis& dataflow,
138                                      FlatValueSet& aliased_values) {
139   VLOG(3) << "Compute kConditional aliases";
140   // Aliases the buffers of the true/false computations roots, with the one of
141   // the conditional.
142   for (const HloPosition& position : value.positions()) {
143     if (!position.instruction->IsRoot()) continue;
144 
145     const HloComputation* computation = position.instruction->parent();
146     const CallGraphNode& call_graph_node =
147         dataflow.call_graph().GetNode(computation);
148     for (const CallSite& callsite : call_graph_node.caller_callsites()) {
149       if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
150         // Call graph must have been flattened.
151         CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
152 
153         const HloValue& cond_value =
154             dataflow.GetUniqueValueAt(callsite.instruction(), position.index);
155         VLOG(3) << "  value @ " << position << " is root of "
156                 << callsite.instruction()->name()
157                 << "; branch computation roots must share buffer among them : "
158                 << cond_value;
159         aliased_values.insert(&cond_value);
160       }
161     }
162   }
163 }
164 
ComputeInPlaceOperationAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow,FlatValueSet & aliased_values)165 void ComputeInPlaceOperationAliasedValues(const HloValue& value,
166                                           const HloDataflowAnalysis& dataflow,
167                                           FlatValueSet& aliased_values) {
168   VLOG(3) << "Compute aliases for in-place operations (e.g. "
169              "kDynamicUpdateSlice and kScatter)";
170   for (const HloPosition& position : value.positions()) {
171     HloInstruction* instruction = position.instruction;
172     for (const auto& operand_and_output_index :
173          HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
174       if (position.index == operand_and_output_index.second) {
175         const HloOperandIndex& operand_index = operand_and_output_index.first;
176         const HloValue& operand_value = dataflow.GetUniqueValueAt(
177             instruction->operand(operand_index.operand_number),
178             operand_index.operand_index);
179         VLOG(3) << " operand value " << operand_value << " aliases.";
180         aliased_values.insert(&operand_value);
181       }
182     }
183   }
184 
185   for (const HloUse& use : value.GetUses()) {
186     for (const auto& operand_and_output_index :
187          HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
188       const HloOperandIndex& operand_index = operand_and_output_index.first;
189       if (use.operand_number == operand_index.operand_number &&
190           use.operand_index == operand_index.operand_index) {
191         const HloValue& use_value = dataflow.GetUniqueValueAt(
192             use.instruction, operand_and_output_index.second);
193         VLOG(3) << " use value " << use_value << " aliases.";
194         aliased_values.insert(&use_value);
195       }
196     }
197   }
198 }
199 
200 // Compute and return a set of values that the given value must be aliased
201 // with due to HLO aliasing rules (including the value itself).
ComputeAliasedValues(const HloValue & value,const HloDataflowAnalysis & dataflow)202 FlatValueSet ComputeAliasedValues(const HloValue& value,
203                                   const HloDataflowAnalysis& dataflow) {
204   if (VLOG_IS_ON(2)) {
205     for (const HloUse& use : value.GetUses()) {
206       VLOG(2) << "Use of value " << value << ": " << use;
207     }
208   }
209 
210   FlatValueSet aliased_values{&value};
211   ComputeInputOutputAliasedValues(value, dataflow, aliased_values);
212   ComputeWhileAliasedValues(value, dataflow, aliased_values);
213   ComputeConditionalAliasedValues(value, dataflow, aliased_values);
214   ComputeInPlaceOperationAliasedValues(value, dataflow, aliased_values);
215   return aliased_values;
216 }
217 
CreateBuffers(const HloDataflowAnalysis & dataflow)218 std::vector<HloBuffer> CreateBuffers(const HloDataflowAnalysis& dataflow) {
219   const std::vector<HloValue*>& values = dataflow.values();
220   size_t num_buffers = values.size();
221   // The sets of values contained in each buffer.
222   std::vector<FlatValueSet> buffer_values(values.size());
223   // Maps values to the set of values with which they are aliased.
224   absl::flat_hash_map<const HloValue*, FlatValueSet*> value_to_set;
225   value_to_set.reserve(values.size());
226 
227   for (size_t i = 0; i < values.size(); ++i) {
228     buffer_values[i].insert(values[i]);
229     value_to_set[values[i]] = &buffer_values[i];
230   }
231 
232   // Merge together sets of HloValues which must be in the same HloBuffer
233   // because of aliasing rules (e.g. in-place kWhile instruction).
234   for (const HloValue* value : values) {
235     VLOG(3) << "Merging colocated values, value: " << *value;
236 
237     FlatValueSet aliased_values = ComputeAliasedValues(*value, dataflow);
238     if (aliased_values.size() < 2) continue;  // Fast path.
239 
240     // The sets of values that are transitively aliased together.
241     std::vector<std::pair<FlatValueSet*, HloValue::Id>> aliased_sets;
242     aliased_sets.reserve(aliased_values.size());
243     for (const HloValue* aliased : aliased_values) {
244       aliased_sets.push_back({value_to_set[aliased], aliased->id()});
245     }
246 
247     // Use the largest set to collect the union of the aliased sets (as it is
248     // more efficient to merge smaller sets into larger). Break ties using
249     // value ID to maintain determinism.
250     auto key = [](const auto& set_and_id) {
251       return std::make_pair(set_and_id.first->size(), -set_and_id.second);
252     };
253     FlatValueSet* union_set =
254         absl::c_max_element(aliased_sets, LessThanByKey(key))->first;
255 
256     for (auto& aliased_set_and_id : aliased_sets) {
257       FlatValueSet* aliased_set = aliased_set_and_id.first;
258       if ((aliased_set != union_set) && !aliased_set->empty()) {
259         for (const HloValue* aliased_value : *aliased_set) {
260           CHECK(union_set->insert(aliased_value).second);
261           value_to_set[aliased_value] = union_set;
262         }
263         aliased_set->clear();
264         --num_buffers;
265       }
266     }
267   }
268 
269   // Create a vector of HloBuffers, one for each non-empty set of values.
270   std::vector<HloBuffer> buffers;
271   buffers.reserve(num_buffers);
272 
273   for (const FlatValueSet& value_set : buffer_values) {
274     if (!value_set.empty()) {
275       HloBuffer::Id id = buffers.size();
276       buffers.push_back({id, HloValueSet(value_set).TakeValues()});
277     }
278   }
279 
280   CHECK_EQ(buffers.size(), num_buffers);
281   return buffers;
282 }
283 
284 }  // namespace
285 
HloAliasAnalysis(const HloModule * module)286 HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
287 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const288 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
289     const HloInstruction* instruction, const ShapeIndex& index) const {
290   std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
291   CHECK_EQ(buffers.size(), 1);
292   return *buffers[0];
293 }
294 
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)295 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
296     const HloInstruction* instruction, const ShapeIndex& index) {
297   return GetBuffer(const_cast<const HloAliasAnalysis*>(this)
298                        ->GetUniqueBufferAt(instruction, index)
299                        .id());
300 }
301 
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const302 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
303     const HloInstruction* instruction, const ShapeIndex& index) const {
304   const HloValueSet& value_set =
305       dataflow_analysis_->GetValueSet(instruction, index);
306   std::vector<const HloBuffer*> buffers;
307   buffers.reserve(value_set.values().size());
308   for (const HloValue* value : value_set.values()) {
309     buffers.push_back(&GetBufferContainingValue(*value));
310   }
311 
312   // Sort and uniquify vector before returning.
313   absl::c_sort(buffers, HloBuffer::IdLessThan);
314   buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
315 
316   return buffers;
317 }
318 
Verify() const319 Status HloAliasAnalysis::Verify() const {
320   // Verify consistency between the value_to_buffer_ map and
321   // HloBuffer::values().
322   for (const auto& pair : value_to_buffer_) {
323     const HloValue* value = pair.first;
324     const HloBuffer& buffer = *pair.second;
325     TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
326   }
327 
328   for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
329     const HloBuffer& buffer = buffers_[id];
330     TF_RET_CHECK(buffer.id() == id);
331 
332     HloValue::Id last_value_id = -1;
333     for (const HloValue* value : buffer.values()) {
334       TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
335 
336       // Also verify the values in HloBuffer are unique and sorted by id.
337       TF_RET_CHECK(value->id() > last_value_id);
338       last_value_id = value->id();
339     }
340   }
341 
342   return OkStatus();
343 }
344 
ToString() const345 std::string HloAliasAnalysis::ToString() const {
346   std::string out =
347       absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
348   StrAppend(&out, "  Buffers at each position:\n");
349   for (const HloComputation* computation : module_->computations()) {
350     for (const HloInstruction* instruction : computation->instructions()) {
351       StrAppend(&out, "    ", instruction->name(), ":\n");
352       if (instruction->shape().IsTuple()) {
353         ShapeUtil::ForEachSubshape(
354             instruction->shape(),
355             [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
356               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
357               for (const HloBuffer* buffer :
358                    ComputeBuffersAt(instruction, index)) {
359                 StrAppend(&out, "        ", buffer->ToString(), "\n");
360               }
361             });
362       } else {
363         for (const HloBuffer* buffer :
364              ComputeBuffersAt(instruction, /*index=*/{})) {
365           StrAppend(&out, "      ", buffer->ToString(), "\n");
366         }
367       }
368     }
369   }
370 
371   StrAppend(&out, "  Buffers:\n");
372   for (const HloBuffer& buffer : buffers()) {
373     StrAppend(&out, "    ", buffer.ToString(), "\n");
374     StrAppend(&out, "      positions:\n");
375     for (const HloPosition& position : buffer.ComputePositions()) {
376       StrAppend(&out, "        ", position.ToString(), "\n");
377     }
378   }
379 
380   return out;
381 }
382 
383 /* static */
Run(const HloModule * module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer)384 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
385     const HloModule* module,
386     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
387   VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
388   XLA_VLOG_LINES(2, module->ToString());
389 
390   auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
391   TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
392                       HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
393                                                /*bitcast_defines_value=*/false,
394                                                can_share_buffer));
395 
396   size_t num_values = alias_analysis->dataflow_analysis_->values().size();
397   alias_analysis->buffers_ = CreateBuffers(alias_analysis->dataflow_analysis());
398   alias_analysis->value_to_buffer_.reserve(num_values);
399 
400   for (HloBuffer& buffer : alias_analysis->buffers_) {
401     for (const HloValue* value : buffer.values()) {
402       alias_analysis->value_to_buffer_[value] = &buffer;
403     }
404   }
405 
406   CHECK_EQ(alias_analysis->value_to_buffer_.size(), num_values);
407   TF_DCHECK_OK(alias_analysis->Verify());
408 
409   HloInstruction* root = module->entry_computation()->root_instruction();
410   ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& /*subshape*/,
411                                                 const ShapeIndex& index) {
412     std::vector<const HloBuffer*> buffers =
413         alias_analysis->ComputeBuffersAt(root, index);
414     alias_analysis->live_out_buffers_.insert(buffers.begin(), buffers.end());
415   });
416 
417   XLA_VLOG_LINES(2, alias_analysis->ToString());
418   return std::move(alias_analysis);
419 }
420 
421 }  // namespace xla
422