/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
// tracking values across computation boundaries.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Analysis which identifies all HLO values and their uses in an HLO module.
43 class HloDataflowAnalysis { 44 public: 45 // Different backends can have very different ways to do fusion, so we give 46 // backends the flexibility to decide whether an fusion instruction can share 47 // buffer with it's operands. If this is not specified, a default strategy 48 // will be used; if this is specified, it will be applied *in addition* to the 49 // default strategy. 50 // 51 // The first parameter of the function should be the fusion instruction, the 52 // second parameter should be an operand of the fusion instruction. 53 // 54 // TODO(b/80315712): Find a better way to tell whether a fusion can share 55 // buffer. 56 using FusionCanShareBufferFunction = std::function<bool( 57 const HloInstruction* fusion, const HloInstruction* operand)>; 58 59 // Run dataflow analysis on the given module. Parameters: 60 // 61 // ssa_form : If true then new values are defined at the merge points of 62 // kWhile instructions. Abusing nomenclature somewhat, we call these "phi 63 // values". The merge is formed by the init value and loop backedge. The 64 // SSA form is minimal in that a new phi value is defined only if the 65 // merge point is reachable by multiple different values. The SSA form is 66 // also in loop-closed form in that no values defined inside of a loop 67 // (while body) is used outside of the loop. 68 // 69 // If ssa_form is false, then merge points do not define new 70 // values. Rather, the HloValueSet for the merge point contains the union 71 // of the merged HloValues. 72 // 73 // bitcast_defines_value : If true then the Bitcast HLO instruction defines 74 // a new HLO value in the analysis. If false then Bitcast forwards the 75 // value of its operand. 
76 static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run( 77 const HloModule& module, bool ssa_form = false, 78 bool bitcast_defines_value = false, 79 const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); 80 81 static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); 82 83 // Returns true if 'instruction' defines an HLO value at the given shape index 84 // of its output. 85 bool ValueIsDefinedAt(const HloInstruction* instruction, 86 const ShapeIndex& index = {}) const; 87 88 // Return the HloValue defined by 'instruction' at the given shape index of 89 // its output. 90 // 91 // Precondition: ValueIsDefinedAt is true for this instruction and index. 92 const HloValue& GetValueDefinedAt(const HloInstruction* instruction, 93 const ShapeIndex& index = {}) const; 94 HloValue& GetValueDefinedAt(const HloInstruction* instruction, 95 const ShapeIndex& index = {}); 96 97 // Return the InstructionValueSet for the given instruction. 98 const InstructionValueSet& GetInstructionValueSet( 99 const HloInstruction* instruction) const; 100 InstructionValueSet& GetInstructionValueSet( 101 const HloInstruction* instruction); 102 103 // Return the HloValueSet for the given instruction at the given index or the 104 // given position. 105 const HloValueSet& GetValueSet(const HloInstruction* instruction, 106 const ShapeIndex& index = {}) const; 107 const HloValueSet& GetValueSet(const HloPosition& position) const; 108 HloValueSet& GetValueSet(const HloPosition& position); 109 HloValueSet& GetValueSet(const HloInstruction* instruction, 110 const ShapeIndex& index = {}); 111 112 // Return the unique value in the HloValueSet at the given instruction and 113 // shape index. CHECKs if the value set does not contain a exactly one value. 
114 const HloValue& GetUniqueValueAt(const HloInstruction* instruction, 115 const ShapeIndex& index = {}) const { 116 return GetValueSet(instruction, index).GetUniqueValue(); 117 } 118 HloValue& GetUniqueValueAt(const HloInstruction* instruction, 119 const ShapeIndex& index = {}) { 120 return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); 121 } 122 123 // Return the HloValue with the given Id. 124 const HloValue& GetValue(HloValue::Id value_id) const; 125 HloValue& GetValue(HloValue::Id value_id); 126 127 // Return the total number of HloValues. value_count()128 int64 value_count() const { return values_.size(); } 129 130 // Return a vector of all HloValues stabily sorted by HloValue::Id. values()131 const std::vector<const HloValue*>& values() const { return values_vector_; } 132 133 // Return the call graph used for computing the dataflow. call_graph()134 const CallGraph& call_graph() const { return *call_graph_; } 135 136 string ToString() const; 137 138 // Returns true if 'user' cannot possibly use the buffer at 'index' in 139 // 'operand'. Returns false otherwise. 140 // 141 // 'operand' does not have to be an operand of 'user'. This can be the case 142 // with indirect uses. 143 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 144 const ShapeIndex& index, 145 const HloInstruction* user) const; 146 147 // Returns true if 'user' (at 'user_index') can share a buffer with its 148 // operand 'operand' (at 'operand_index'). Returns false otherwise. 149 // 150 // REQUIRES: 'operand' is an operand of 'user'. 
151 bool CanShareOperandBufferWithUser(HloInstruction* operand, 152 const ShapeIndex& operand_index, 153 HloInstruction* user, 154 const ShapeIndex& user_index) const; 155 156 protected: 157 HloDataflowAnalysis( 158 const HloModule& module, bool ssa_form, 159 bool bitcast_defines_value = false, 160 const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); 161 162 // Returns a new HloValue defined at the given instruction and shape index. 163 HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, 164 bool is_phi = false); 165 166 // Mark the HloValue with the given ID for deletion. 167 void MarkValueForDeletion(HloValue::Id value_id); 168 169 // Delete all HloValues marked for deletion. Should be called after 170 // propagation is complete. 171 void DeleteMarkedValues(); 172 173 // Constructs and initializes the InstructionValueSets of all instructions to 174 // contain exactly the HloValues defined by each instruction. These values can 175 // then propagated throughout the HLO graph by calling Propagate. 176 Status InitializeInstructionValueSets(); 177 178 // Updates the value set of the given instruction based on the values flowing 179 // into the instruction (operands and cross-computation dataflow). 180 bool UpdateInstructionValueSet(HloInstruction* instruction); 181 182 // Updates the value set for a particular instruction type. Returns whether 183 // the instruction value set changed. 
184 bool UpdateBitcastValueSet(HloInstruction* bitcast); 185 bool UpdateCallValueSet(HloInstruction* call); 186 bool UpdateConditionalValueSet(HloInstruction* conditional); 187 bool UpdateCopyValueSet(HloInstruction* copy); 188 bool UpdateDomainValueSet(HloInstruction* domain); 189 bool UpdateGetTupleElementValueSet(HloInstruction* gte); 190 bool UpdateParameterValueSet(HloInstruction* parameter); 191 bool UpdateRecvDoneValueSet(HloInstruction* recv_done); 192 bool UpdateTupleSelectValueSet(HloInstruction* select); 193 bool UpdateSendValueSet(HloInstruction* send); 194 bool UpdateTupleValueSet(HloInstruction* tuple); 195 bool UpdateWhileValueSet(HloInstruction* xla_while); 196 bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); 197 198 // Propagate the dataflow through the module. 199 void Propagate(); 200 201 // Return the result of the SSA Phi function applied to the given inputs at 202 // the given instruction. If skip_top_level is true, then the top level of the 203 // value set of 'instruction' is not modified. 204 bool Phi(HloInstruction* instruction, 205 absl::Span<const InstructionValueSet* const> inputs); 206 207 // Updates the positions of the HloValues in the output of the given 208 // instruction. This should be called after the instruction value set of 209 // 'instruction' has been changed. 'prev_value_set' must point to the previous 210 // state of the value set prior to the change. 'prev_value_set' may be null if 211 // this is the first time positions are being computed. The previous state is 212 // necessary to efficiently remove positions which have been eliminated due to 213 // changes in the instructions' InstructionValueSet. 214 void UpdatePositionsOfValuesAt( 215 HloInstruction* instruction, const InstructionValueSet& new_value_set, 216 const InstructionValueSet* prev_value_set = nullptr); 217 218 // Verify various invariants of the dataflow analysis. 
219 Status Verify() const; 220 221 const HloModule& module_; 222 const bool ssa_form_; 223 const bool bitcast_defines_value_; 224 225 std::unique_ptr<CallGraph> call_graph_; 226 227 // The map of all HloValues in the module. We pass around pointers to the 228 // mapped HloValues, so the underlying container must keep them valid despite 229 // mutations touching other map entries. 230 std::unordered_map<HloValue::Id, HloValue> values_; 231 232 // A map from instruction to InstructionValueSet. 233 std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_; 234 235 // Values marked for deletion during construction. We don't delete them 236 // immediately because references to them may remain in ValueSets temporarily 237 // during propagation. After construction, these values are deleted. 238 std::vector<HloValue::Id> value_ids_to_delete_; 239 240 // A vector containing all HloValues sorted by HloValue::Id. 241 std::vector<const HloValue*> values_vector_; 242 243 // The Id to use for the next HloValue. 244 HloValue::Id next_value_id_ = 0; 245 246 // Backend specific function that decides whether a fusion can share buffer 247 // with its operand. 248 FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr; 249 }; 250 251 } // namespace xla 252 253 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 254