• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"

#include <algorithm>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
24 
25 namespace xla {
26 
FusionNodeIndexingEvaluation(const HloInstruction * fusion,int64_t root_usage_count)27 FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
28     const HloInstruction* fusion, int64_t root_usage_count)
29     : fusion_(fusion) {
30   HloInstruction* root = fusion->fused_expression_root();
31   indexing_users_[root].insert(fusion);
32   index_usage_count_[fusion] = root_usage_count;
33   RecomputeCache();
34 }
35 
36 // This constant is arbitrarily chosen. Essentially we don't want to have too
37 // much code duplication, because it slows down the compilation time. There is
38 // a tradeoff between compilation time and runtime here.
39 const int64_t FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
40 
41 namespace {
42 
43 // Returns which ops invalidate the cache of emitted instructions by creating a
44 // new BasicBlock and setting the insertion point to the newly created
45 // BasicBlock. We can only reuse cached values if they were emitted in the same
46 // BasicBlock as the current BasicBlock.
OpInvalidatesCache(const HloInstruction * hlo)47 bool OpInvalidatesCache(const HloInstruction* hlo) {
48   switch (hlo->opcode()) {
49     // This list of ops was created by inspecting the code. There is no
50     // guarantee that it is complete.
51     case HloOpcode::kConcatenate:
52     case HloOpcode::kDot:
53     case HloOpcode::kDynamicUpdateSlice:
54     case HloOpcode::kPad:
55     case HloOpcode::kReduce:
56     case HloOpcode::kReduceWindow:
57       return true;
58     default:
59       return false;
60   }
61 }
62 
63 // Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
64 // user, we consider the users of the fusion parameter corresponding to 'hlo' as
65 // the real users.
UserCount(const HloInstruction * hlo)66 int64_t UserCount(const HloInstruction* hlo) {
67   int64_t cnt = 0;
68   for (HloInstruction* user : hlo->users()) {
69     if (user->opcode() == HloOpcode::kFusion) {
70       // Count the number of users of the parameter corresponding to the fusion
71       // operand.
72       int64_t operand_index = user->operand_index(hlo);
73       cnt += user->fused_parameter(operand_index)->user_count();
74     } else {
75       ++cnt;
76     }
77   }
78   return cnt;
79 }
80 }  // namespace
81 
CodeDuplicationTooHigh(const HloInstruction * producer) const82 bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
83     const HloInstruction* producer) const {
84   int64_t emitted_instructions = EvaluateEmittedInstructions(producer);
85   return emitted_instructions > kAllowedCodeDuplication ||
86          (OpInvalidatesCache(producer) &&
87           (emitted_instructions > 1 || UserCount(producer) > 1));
88 }
89 
MaxCodeDuplicationTooHigh() const90 bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
91   for (const auto& entry : index_usage_count_) {
92     if (entry.second > kAllowedCodeDuplication ||
93         (OpInvalidatesCache(entry.first) &&
94          (entry.second > 1 || UserCount(entry.first) > 1))) {
95       return true;
96     }
97   }
98   return false;
99 }
100 
EvaluateEmittedInstructions(const HloInstruction * producer) const101 int64_t FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
102     const HloInstruction* producer) const {
103   int64_t total = 0;
104   for (const auto* user : indexing_users_.at(producer)) {
105     total += index_usage_count_.at(user);
106   }
107   return total;
108 }
109 
UpdateEvaluationCache(const HloInstruction * producer,absl::flat_hash_set<const HloInstruction * > indexing_users_of_producer)110 void FusionNodeIndexingEvaluation::UpdateEvaluationCache(
111     const HloInstruction* producer,
112     absl::flat_hash_set<const HloInstruction*> indexing_users_of_producer) {
113   CHECK(!indexing_users_.contains(producer));
114   indexing_users_[producer] = std::move(indexing_users_of_producer);
115   UpdateIndexUsageCount(producer);
116   UpdateIndexingUsersOfOperands(producer);
117 }
118 
119 absl::flat_hash_set<const HloInstruction*>
RemoveFusionOperand(HloInstruction * fusion_operand)120 FusionNodeIndexingEvaluation::RemoveFusionOperand(
121     HloInstruction* fusion_operand) {
122   auto indexing_users_of_operand =
123       std::move(indexing_users_.at(fusion_operand));
124   indexing_users_.erase(fusion_operand);
125   CHECK(!index_usage_count_.contains(fusion_operand));
126   return indexing_users_of_operand;
127 }
128 
RecomputeCache()129 void FusionNodeIndexingEvaluation::RecomputeCache() {
130   auto postorder =
131       fusion_->fused_instructions_computation()->MakeInstructionPostOrder();
132   std::reverse(postorder.begin(), postorder.end());
133   for (const auto* instruction : postorder) {
134     if (instruction->opcode() == HloOpcode::kParameter) {
135       continue;
136     }
137     UpdateIndexUsageCount(instruction);
138     UpdateIndexingUsersOfOperands(instruction);
139   }
140 }
141 
UpdateIndexUsageCount(const HloInstruction * instruction)142 void FusionNodeIndexingEvaluation::UpdateIndexUsageCount(
143     const HloInstruction* instruction) {
144   int64_t total = 0;
145   for (const auto* user : indexing_users_[instruction]) {
146     total += index_usage_count_.at(user);
147   }
148   CHECK(index_usage_count_.emplace(instruction, total).second);
149 }
150 
UpdateIndexingUsersOfOperands(const HloInstruction * instruction)151 void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands(
152     const HloInstruction* instruction) {
153   for (const auto* operand : instruction->operands()) {
154     if (operand->opcode() == HloOpcode::kParameter) {
155       // Although actually the parameter gets indexed, we store it as indexing
156       // of the corresponding fusion operand instead because parameter
157       // instruction pointers can be invalidated when we fuse another
158       // instruction into 'fusion_'.
159       operand = fusion_->operand(operand->parameter_number());
160     }
161     // For simplicity we assume that all shape and layout changing
162     // operations except Transposes invalidate index reuse. Transposes are
163     // special: although they are shape changing, we can reuse the
164     // multi-dimensional index for the operand by permuting it.
165     if (instruction->opcode() == HloOpcode::kTranspose ||
166         Shape::Equal().IgnoreElementType()(operand->shape(),
167                                            instruction->shape())) {
168       // If the index is reused, it means the operand gets index values
169       // from the same set of (indirect) users as 'instruction' itself.
170       indexing_users_[operand].insert(indexing_users_[instruction].begin(),
171                                       indexing_users_[instruction].end());
172     } else {
173       // If the index is not reused, it means 'instruction' computes a
174       // new index derived from the index it gets.
175       indexing_users_[operand].insert(instruction);
176     }
177   }
178 }
179 
180 }  // namespace xla
181