/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"

#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
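// Nested fusions are gathered before the fusion that contains them; e.g. for
// a fusion F whose fused computation holds another fusion G, the result is
// {G, F}.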
void GatherFusionInstructions(
    HloInstruction* instruction,
    std::vector<HloInstruction*>* fusion_instructions) {
  CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
  for (auto* fused : instruction->fused_instructions()) {
    if (fused->opcode() == HloOpcode::kFusion) {
      GatherFusionInstructions(fused, fusion_instructions);
    }
  }
  fusion_instructions->push_back(instruction);
}

}  // namespace

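// Example usage (a sketch, not part of the analysis itself; 'module' stands
// for any HloModule*): run the analysis once per module, then query buffers
// by (instruction, index):
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<LogicalBufferAnalysis> analysis,
//                       LogicalBufferAnalysis::Run(module));
//   const LogicalBuffer& root_buffer = analysis->GetBuffer(
//       module->entry_computation()->root_instruction(), /*index=*/{});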
/* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
LogicalBufferAnalysis::Run(const HloModule* module) {
  std::unique_ptr<LogicalBufferAnalysis> analysis(
      new LogicalBufferAnalysis(module));
  TF_RETURN_IF_ERROR(analysis->Analyze());
  return std::move(analysis);
}

Status LogicalBufferAnalysis::Analyze() {
  // Empirically we usually have a few more logical buffers than instructions,
  // so reserve 10% more than the number of instructions to avoid frequent
  // resizes.
  logical_buffers_.clear();
  logical_buffers_.reserve((module_->instruction_count() * 11) / 10);

  // We filter out fusion computations, and get to them through fusion
  // instructions. This is because it's possible to have orphaned (unreachable)
  // fusion computations, and we don't want to try to assign buffers to those.
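  // For example (illustrative): if the module contains an entry computation E,
  // a while-body computation W, and a fusion computation F reachable only from
  // a kFusion instruction in E, MakeNonfusionComputations() yields {E, W};
  // F's instructions are reached below through the fusion instruction's
  // fused_expression_root().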
  std::vector<HloInstruction*> fusion_instructions;
  for (auto* computation : module_->MakeNonfusionComputations()) {
    TF_RETURN_IF_ERROR(computation->Accept(this));
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kFusion) {
        continue;
      }
      GatherFusionInstructions(instruction, &fusion_instructions);
    }
  }
  for (auto* instruction : fusion_instructions) {
    TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
  }
  return Status::OK();
}

LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const {
  CHECK_GE(id, 0);
  CHECK_LT(id, logical_buffers_.size());
  return *logical_buffers_[id];
}

LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction,
                                                const ShapeIndex& index) const {
  return *output_buffers_.at(std::make_pair(instruction, index));
}

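// Buffer ids are assigned densely in creation order, so a buffer's id is also
// its position in logical_buffers_. GetBuffer(id) above relies on this
// invariant, which the CHECK_EQ below enforces.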
void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
                                             const ShapeIndex& index) {
  CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
  logical_buffers_.emplace_back(
      absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
  output_buffers_[std::make_pair(instruction, index)] =
      logical_buffers_.back().get();

  ++next_buffer_id_;
}

Status LogicalBufferAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
  // Create a logical buffer for each output of the instruction.
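  // For example (illustrative): if hlo_instruction's shape is the nested
  // tuple (f32[2], (s32[], f32[3])), ForEachSubshape visits the indices
  // {}, {0}, {1}, {1,0}, and {1,1}, so five logical buffers are created, one
  // per subshape.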
  ShapeUtil::ForEachSubshape(
      hlo_instruction->shape(),
      [this, hlo_instruction](const Shape& shape, const ShapeIndex& index) {
        NewLogicalBuffer(hlo_instruction, index);
      });

  return Status::OK();
}

Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) {
  // GetTupleElement does not create buffers.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleAddDependency(
    HloInstruction* add_dependency) {
  // AddDependency just forwards the value of its zero-th operand and does not
  // create buffers.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) {
  // The top-level buffer (index={}) for kCopy is newly created, but all other
  // buffers (in the case of a tuple shape) come from the operand.
  NewLogicalBuffer(copy, /*index=*/{});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
  // A kBitcast instruction aliases its operand. That is, the buffer of its
  // result *is* the buffer of its operand.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
  // A kDomain instruction aliases its operand. That is, the buffer of its
  // result *is* the buffer of its operand.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
  // RecvDone produces a two-element tuple containing the data value (which
  // aliases part of its operand) and a token. Only the tuple index table and
  // the token are defined by the RecvDone.
  NewLogicalBuffer(recv_done, /*index=*/{});
  NewLogicalBuffer(recv_done, /*index=*/{1});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
  // Send creates new buffers for the top-level tuple, the context (tuple
  // element at {1}), and the token (tuple element at {2}). Tuple element at
  // {0} is an alias of the Send operand, so we don't need to create a new
  // LogicalBuffer for that.
  NewLogicalBuffer(send, /*index=*/{});
  NewLogicalBuffer(send, /*index=*/{1});
  NewLogicalBuffer(send, /*index=*/{2});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleCopyStart(HloInstruction* copy_start) {
  // CopyStart defines the tuple, target buffer at index {0}, and context at
  // index {2}. The source at index {1} aliases the operand, so no new buffer
  // is created for it.
  NewLogicalBuffer(copy_start, /*index=*/{});
  NewLogicalBuffer(copy_start, /*index=*/{0});
  NewLogicalBuffer(copy_start, /*index=*/{2});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleCopyDone(HloInstruction* copy_done) {
  // The output of CopyDone aliases element {0} of its CopyStart operand, so
  // CopyDone doesn't create any buffers.
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
  // A Tuple instruction only creates the top-level buffer.
  NewLogicalBuffer(tuple, /*index=*/{});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
  // TupleSelect allocates a new top-level buffer and then shallow copies the
  // on_true or on_false buffer into this new buffer.
  NewLogicalBuffer(tuple_select, /*index=*/{});
  return Status::OK();
}

Status LogicalBufferAnalysis::HandleCustomCall(HloInstruction* custom_call) {
  // A custom call may declare that some of its output indices alias operand
  // buffers. Create new logical buffers only for the outputs that are not
  // aliased.
  auto ccall = Cast<HloCustomCallInstruction>(custom_call);
  absl::flat_hash_set<ShapeIndex> aliased_outputs;
  for (const auto& pair : ccall->output_to_operand_aliasing()) {
    aliased_outputs.insert(pair.first);
  }
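  // Illustrative example (hypothetical aliasing config): with
  // output_to_operand_aliasing() == {({0}, (0, {}))}, output index {0} reuses
  // operand 0's buffer, so only the remaining output indices (the top-level
  // tuple {} and, say, {1}) receive new logical buffers below.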
  ShapeUtil::ForEachSubshape(ccall->shape(),
                             [&](const Shape& shape, const ShapeIndex& index) {
                               if (!aliased_outputs.contains(index)) {
                                 NewLogicalBuffer(custom_call, index);
                               }
                             });
  return Status::OK();
}

}  // namespace xla