1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
17
18 #include <utility>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/logging.h"
24
25 namespace xla {
26
27 namespace {
28
29 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)30 void GatherFusionInstructions(
31 HloInstruction* instruction,
32 std::vector<HloInstruction*>* fusion_instructions) {
33 CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
34 for (auto* fused : instruction->fused_instructions()) {
35 if (fused->opcode() == HloOpcode::kFusion) {
36 GatherFusionInstructions(fused, fusion_instructions);
37 }
38 }
39 fusion_instructions->push_back(instruction);
40 }
41
42 } // namespace
43
44 /* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
Run(const HloModule * module)45 LogicalBufferAnalysis::Run(const HloModule* module) {
46 std::unique_ptr<LogicalBufferAnalysis> analysis(
47 new LogicalBufferAnalysis(module));
48 TF_RETURN_IF_ERROR(analysis->Analyze());
49 return std::move(analysis);
50 }
51
Analyze()52 Status LogicalBufferAnalysis::Analyze() {
53 // Empirically we usually have a few more logical buffers than instructions,
54 // so reserve 10% more than the number of instructions to avoid frequent
55 // resizes.
56 logical_buffers_.clear();
57 logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
58
59 // We filter out fusion computations, and get to them through fusion
60 // instructions. This is because it's possible to have orphaned (unreachable)
61 // fusion computations, and we don't want to try to assign buffers to those.
62 std::vector<HloInstruction*> fusion_instructions;
63 for (auto* computation : module_->MakeNonfusionComputations()) {
64 TF_RETURN_IF_ERROR(computation->Accept(this));
65 for (auto* instruction : computation->instructions()) {
66 if (instruction->opcode() != HloOpcode::kFusion) {
67 continue;
68 }
69 GatherFusionInstructions(instruction, &fusion_instructions);
70 }
71 }
72 for (auto* instruction : fusion_instructions) {
73 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
74 }
75 return Status::OK();
76 }
77
GetBuffer(LogicalBuffer::Id id) const78 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const {
79 CHECK_GE(id, 0);
80 CHECK_LT(id, logical_buffers_.size());
81 return *logical_buffers_[id];
82 }
83
GetBuffer(HloInstruction * instruction,const ShapeIndex & index) const84 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction,
85 const ShapeIndex& index) const {
86 return *output_buffers_.at(std::make_pair(instruction, index));
87 }
88
NewLogicalBuffer(HloInstruction * instruction,const ShapeIndex & index)89 void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
90 const ShapeIndex& index) {
91 CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
92 logical_buffers_.emplace_back(
93 absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
94 output_buffers_[std::make_pair(instruction, index)] =
95 logical_buffers_.back().get();
96
97 ++next_buffer_id_;
98 }
99
DefaultAction(HloInstruction * hlo_instruction)100 Status LogicalBufferAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
101 // Create a logical buffer for each output of the instruction.
102 ShapeUtil::ForEachSubshape(
103 hlo_instruction->shape(),
104 [this, hlo_instruction](const Shape& shape, const ShapeIndex& index) {
105 NewLogicalBuffer(hlo_instruction, index);
106 });
107
108 return Status::OK();
109 }
110
HandleGetTupleElement(HloInstruction *)111 Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) {
112 // GetTupleElement does not create buffers.
113 return Status::OK();
114 }
115
HandleAddDependency(HloInstruction * add_dependency)116 Status LogicalBufferAnalysis::HandleAddDependency(
117 HloInstruction* add_dependency) {
118 // AddDependency just forwards the value of its zero-th operand and does not
119 // create buffers.
120 return Status::OK();
121 }
122
HandleCopy(HloInstruction * copy)123 Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) {
124 // The top-level buffer (index={}) for kCopy is newly created, but all other
125 // buffers (in the case of a tuple shape) come from the operand
126 NewLogicalBuffer(copy, /*index=*/{});
127 return Status::OK();
128 }
129
HandleBitcast(HloInstruction *)130 Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
131 // A kBitcast instruction aliases its operand. That is, the buffer of its
132 // result *is* the buffer of its operand.
133 return Status::OK();
134 }
135
HandleDomain(HloInstruction *)136 Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
137 // A kDomain instruction aliases its operand. That is, the buffer of its
138 // result *is* the buffer of its operand.
139 return Status::OK();
140 }
141
HandleRecvDone(HloInstruction * recv_done)142 Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
143 // RecvDone produces a two-element tuple containing the data value (which
144 // aliases part of its operand) and a token. Only the tuple index table and
145 // the token are defined by the RecvDone.
146 NewLogicalBuffer(recv_done, /*index=*/{});
147 NewLogicalBuffer(recv_done, /*index=*/{1});
148 return Status::OK();
149 }
150
HandleSend(HloInstruction * send)151 Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
152 // Send creates new buffers for the top-level tuple, the context (tuple
153 // element at {1}), and the token (tuple element at {2}). Tuple element at {0}
154 // is an alias of the Send operand, so we don't need to create a new Logical
155 // Buffer for that.
156 NewLogicalBuffer(send, /*index=*/{});
157 NewLogicalBuffer(send, /*index=*/{1});
158 NewLogicalBuffer(send, /*index=*/{2});
159 return Status::OK();
160 }
161
HandleTuple(HloInstruction * tuple)162 Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
163 // A Tuple instruction only creates the top-level buffer.
164 NewLogicalBuffer(tuple, /*index=*/{});
165 return Status::OK();
166 }
167
HandleTupleSelect(HloInstruction * tuple_select)168 Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
169 // Select allocates a new buffer and then shallow copies the on_true or
170 // on_false buffer into this new buffer.
171 NewLogicalBuffer(tuple_select, /*index=*/{});
172 return Status::OK();
173 }
174
175 } // namespace xla
176