1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/flat_hash_set.h" 21 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 22 #include "tensorflow/compiler/xla/types.h" 23 24 namespace xla { 25 class FusionNodeIndexingEvaluation { 26 public: 27 explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion, 28 int64_t root_usage_count = 1); 29 30 // Evaluate the number of times 'producer' would be emitted if it is fused 31 // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen 32 // constant), returns true. 33 bool CodeDuplicationTooHigh(const HloInstruction* producer) const; 34 35 // Evaluate the maximum code duplication inside the fusion node. If the 36 // maximum code duplication is "too high" (some arbitrary chosen constant), 37 // returns true. 38 bool MaxCodeDuplicationTooHigh() const; 39 40 // Evaluate the number of times 'producer' would be emitted if it is fused 41 // into 'fusion_'. 42 int64_t EvaluateEmittedInstructions(const HloInstruction* producer) const; 43 44 // Update the evaluation cache after having fused 'producer' into 'fusion_'. 45 // 'producer' is the cloned instruction which is now part of the fusion 46 // computation. 'indexing_users_of_producer' are the direct or indirect users 47 // of 'producer' which pass index values created by them. 48 void UpdateEvaluationCache( 49 const HloInstruction* producer, 50 absl::flat_hash_set<const HloInstruction*> indexing_users_of_producer); 51 52 // Prior to fusing, we need to erase the indexing_users_ entry of the 53 // producer to be fused, because the HloInstruction pointer will be 54 // invalidated. We return the set of direct or indirect users which pass index 55 // values created by them to the fusion parameter corresponding to this 56 // producer. This will be needed for updating the evaluation cache (see 57 // UpdateEvaluationCache). 58 absl::flat_hash_set<const HloInstruction*> RemoveFusionOperand( 59 HloInstruction* fusion_operand); 60 61 private: 62 static const int64_t kAllowedCodeDuplication; 63 64 // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the 65 // current instructions inside the fusion node. Also updates 66 // 'total_emitted_instructions_' accordingly. 67 void RecomputeCache(); 68 69 // Computes the 'index_usage_count_' entry for 'instruction'. 70 void UpdateIndexUsageCount(const HloInstruction* instruction); 71 72 // Updates the 'indexing_users_' entry of the operands of 'instruction'. 73 void UpdateIndexingUsersOfOperands(const HloInstruction* instruction); 74 75 // Collects for each instruction in a fusion node from which direct or 76 // indirect users newly created index values are passed. Roughly speaking, we 77 // reuse index values if the shapes are equal when ignoring the element type 78 // (we may reuse also if the shape change is a bitcast, but we don't consider 79 // that here). By ignoring potential reuses our estimate of which instruction 80 // generates a new index value is a bit more conservative than necessary. 81 absl::flat_hash_map<const HloInstruction*, 82 absl::flat_hash_set<const HloInstruction*>> 83 indexing_users_; 84 85 // Stores the number of different index accesses for each instruction in a 86 // fusion node. The fusion emitter caches access with the same index, so this 87 // value indicates how many times a specific instruction will be emitted. 88 absl::flat_hash_map<const HloInstruction*, int64_t> index_usage_count_; 89 90 // The fusion instruction. 91 const HloInstruction* fusion_; 92 }; 93 } // namespace xla 94 95 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_ 96