1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/core/platform/logging.h"
24
25 namespace xla {
26
FusionNodeIndexingEvaluation(const HloInstruction * fusion,int64_t root_usage_count)27 FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
28 const HloInstruction* fusion, int64_t root_usage_count)
29 : fusion_(fusion) {
30 HloInstruction* root = fusion->fused_expression_root();
31 indexing_users_[root].insert(fusion);
32 index_usage_count_[fusion] = root_usage_count;
33 RecomputeCache();
34 }
35
36 // This constant is arbitrarily chosen. Essentially we don't want to have too
37 // much code duplication, because it slows down the compilation time. There is
38 // a tradeoff between compilation time and runtime here.
39 const int64_t FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
40
41 namespace {
42
43 // Returns which ops invalidate the cache of emitted instructions by creating a
44 // new BasicBlock and setting the insertion point to the newly created
45 // BasicBlock. We can only reuse cached values if they were emitted in the same
46 // BasicBlock as the current BasicBlock.
OpInvalidatesCache(const HloInstruction * hlo)47 bool OpInvalidatesCache(const HloInstruction* hlo) {
48 switch (hlo->opcode()) {
49 // This list of ops was created by inspecting the code. There is no
50 // guarantee that it is complete.
51 case HloOpcode::kConcatenate:
52 case HloOpcode::kDot:
53 case HloOpcode::kDynamicUpdateSlice:
54 case HloOpcode::kPad:
55 case HloOpcode::kReduce:
56 case HloOpcode::kReduceWindow:
57 return true;
58 default:
59 return false;
60 }
61 }
62
63 // Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
64 // user, we consider the users of the fusion parameter corresponding to 'hlo' as
65 // the real users.
UserCount(const HloInstruction * hlo)66 int64_t UserCount(const HloInstruction* hlo) {
67 int64_t cnt = 0;
68 for (HloInstruction* user : hlo->users()) {
69 if (user->opcode() == HloOpcode::kFusion) {
70 // Count the number of users of the parameter corresponding to the fusion
71 // operand.
72 int64_t operand_index = user->operand_index(hlo);
73 cnt += user->fused_parameter(operand_index)->user_count();
74 } else {
75 ++cnt;
76 }
77 }
78 return cnt;
79 }
80 } // namespace
81
CodeDuplicationTooHigh(const HloInstruction * producer) const82 bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
83 const HloInstruction* producer) const {
84 int64_t emitted_instructions = EvaluateEmittedInstructions(producer);
85 return emitted_instructions > kAllowedCodeDuplication ||
86 (OpInvalidatesCache(producer) &&
87 (emitted_instructions > 1 || UserCount(producer) > 1));
88 }
89
MaxCodeDuplicationTooHigh() const90 bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
91 for (const auto& entry : index_usage_count_) {
92 if (entry.second > kAllowedCodeDuplication ||
93 (OpInvalidatesCache(entry.first) &&
94 (entry.second > 1 || UserCount(entry.first) > 1))) {
95 return true;
96 }
97 }
98 return false;
99 }
100
EvaluateEmittedInstructions(const HloInstruction * producer) const101 int64_t FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
102 const HloInstruction* producer) const {
103 int64_t total = 0;
104 for (const auto* user : indexing_users_.at(producer)) {
105 total += index_usage_count_.at(user);
106 }
107 return total;
108 }
109
UpdateEvaluationCache(const HloInstruction * producer,absl::flat_hash_set<const HloInstruction * > indexing_users_of_producer)110 void FusionNodeIndexingEvaluation::UpdateEvaluationCache(
111 const HloInstruction* producer,
112 absl::flat_hash_set<const HloInstruction*> indexing_users_of_producer) {
113 CHECK(!indexing_users_.contains(producer));
114 indexing_users_[producer] = std::move(indexing_users_of_producer);
115 UpdateIndexUsageCount(producer);
116 UpdateIndexingUsersOfOperands(producer);
117 }
118
119 absl::flat_hash_set<const HloInstruction*>
RemoveFusionOperand(HloInstruction * fusion_operand)120 FusionNodeIndexingEvaluation::RemoveFusionOperand(
121 HloInstruction* fusion_operand) {
122 auto indexing_users_of_operand =
123 std::move(indexing_users_.at(fusion_operand));
124 indexing_users_.erase(fusion_operand);
125 CHECK(!index_usage_count_.contains(fusion_operand));
126 return indexing_users_of_operand;
127 }
128
RecomputeCache()129 void FusionNodeIndexingEvaluation::RecomputeCache() {
130 auto postorder =
131 fusion_->fused_instructions_computation()->MakeInstructionPostOrder();
132 std::reverse(postorder.begin(), postorder.end());
133 for (const auto* instruction : postorder) {
134 if (instruction->opcode() == HloOpcode::kParameter) {
135 continue;
136 }
137 UpdateIndexUsageCount(instruction);
138 UpdateIndexingUsersOfOperands(instruction);
139 }
140 }
141
UpdateIndexUsageCount(const HloInstruction * instruction)142 void FusionNodeIndexingEvaluation::UpdateIndexUsageCount(
143 const HloInstruction* instruction) {
144 int64_t total = 0;
145 for (const auto* user : indexing_users_[instruction]) {
146 total += index_usage_count_.at(user);
147 }
148 CHECK(index_usage_count_.emplace(instruction, total).second);
149 }
150
UpdateIndexingUsersOfOperands(const HloInstruction * instruction)151 void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands(
152 const HloInstruction* instruction) {
153 for (const auto* operand : instruction->operands()) {
154 if (operand->opcode() == HloOpcode::kParameter) {
155 // Although actually the parameter gets indexed, we store it as indexing
156 // of the corresponding fusion operand instead because parameter
157 // instruction pointers can be invalidated when we fuse another
158 // instruction into 'fusion_'.
159 operand = fusion_->operand(operand->parameter_number());
160 }
161 // For simplicity we assume that all shape and layout changing
162 // operations except Transposes invalidate index reuse. Transposes are
163 // special: although they are shape changing, we can reuse the
164 // multi-dimensional index for the operand by permuting it.
165 if (instruction->opcode() == HloOpcode::kTranspose ||
166 Shape::Equal().IgnoreElementType()(operand->shape(),
167 instruction->shape())) {
168 // If the index is reused, it means the operand gets index values
169 // from the same set of (indirect) users as 'instruction' itself.
170 indexing_users_[operand].insert(indexing_users_[instruction].begin(),
171 indexing_users_[instruction].end());
172 } else {
173 // If the index is not reused, it means 'instruction' computes a
174 // new index derived from the index it gets.
175 indexing_users_[operand].insert(instruction);
176 }
177 }
178 }
179
180 } // namespace xla
181