1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Analysis for determining the possible set of values for all positions 17 // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped 18 // tracking values across computation boundaries. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 22 23 #include <memory> 24 #include <string> 25 #include <unordered_map> 26 #include <vector> 27 28 #include "absl/types/span.h" 29 #include "tensorflow/compiler/xla/service/call_graph.h" 30 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/hlo_value.h" 33 #include "tensorflow/compiler/xla/shape_util.h" 34 #include "tensorflow/compiler/xla/status.h" 35 #include "tensorflow/compiler/xla/statusor.h" 36 #include "tensorflow/compiler/xla/types.h" 37 #include "tensorflow/compiler/xla/xla_data.pb.h" 38 #include "tensorflow/core/platform/macros.h" 39 40 namespace xla { 41 42 // Analysis which identifies all HLO values and their uses in an HLO module. 43 class HloDataflowAnalysis { 44 public: 45 // Infrastructure for passing may-alias hints: HLO passes can populate the 46 // may-alias table. If an empty optional is returned, default rules are used. 47 // 48 // The first parameter of the function should be the instruction, the 49 // second parameter should be an operand of the instruction. The third 50 // parameter should be the output index of the instruction. 51 using CanShareBuffer = std::function<absl::optional<bool>( 52 const HloInstruction* instr, const HloInstruction* operand, 53 const ShapeIndex& user_index)>; 54 55 // Runs dataflow analysis on the given module. Parameters: 56 // 57 // ssa_form : If true then new values are defined at the merge points of 58 // kWhile instructions. Abusing nomenclature somewhat, we call these "phi 59 // values". The merge is formed by the init value and loop backedge. The 60 // SSA form is minimal in that a new phi value is defined only if the 61 // merge point is reachable by multiple different values. The SSA form is 62 // also in loop-closed form in that no values defined inside of a loop 63 // (while body) is used outside of the loop. 64 // 65 // If ssa_form is false, then merge points do not define new 66 // values. Rather, the HloValueSet for the merge point contains the union 67 // of the merged HloValues. 68 // 69 // bitcast_defines_value : If true then the Bitcast HLO instruction defines 70 // a new HLO value in the analysis. If false then Bitcast forwards the 71 // value of its operand. 72 static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run( 73 const HloModule& module, bool ssa_form = false, 74 bool bitcast_defines_value = false, 75 const CanShareBuffer& can_share_buffer = nullptr); 76 77 static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); 78 79 // Returns true if 'instruction' defines an HLO value at the given shape index 80 // of its output. 81 bool ValueIsDefinedAt(const HloInstruction* instruction, 82 const ShapeIndex& index = {}) const; 83 84 // Returns the HloValue defined by 'instruction' at the given shape index of 85 // its output. 86 // 87 // Precondition: ValueIsDefinedAt is true for this instruction and index. 88 const HloValue& GetValueDefinedAt(const HloInstruction* instruction, 89 const ShapeIndex& index = {}) const; 90 HloValue& GetValueDefinedAt(const HloInstruction* instruction, 91 const ShapeIndex& index = {}); 92 93 // Returns the InstructionValueSet for the given instruction. 94 const InstructionValueSet& GetInstructionValueSet( 95 const HloInstruction* instruction) const; 96 InstructionValueSet& GetInstructionValueSet( 97 const HloInstruction* instruction); 98 99 // Returns all values that are contained in the output of this instruction in 100 // a flattened set. 101 HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; 102 103 // Returns the HloValueSet for the given instruction at the given index or the 104 // given position. 105 const HloValueSet& GetValueSet(const HloInstruction* instruction, 106 const ShapeIndex& index = {}) const; 107 const HloValueSet& GetValueSet(const HloPosition& position) const; 108 HloValueSet& GetValueSet(const HloPosition& position); 109 HloValueSet& GetValueSet(const HloInstruction* instruction, 110 const ShapeIndex& index = {}); 111 112 // Returns the unique value in the HloValueSet at the given instruction and 113 // shape index. CHECKs if the value set does not contain a exactly one value. 114 const HloValue& GetUniqueValueAt(const HloInstruction* instruction, 115 const ShapeIndex& index = {}) const { 116 return GetValueSet(instruction, index).GetUniqueValue(); 117 } 118 HloValue& GetUniqueValueAt(const HloInstruction* instruction, 119 const ShapeIndex& index = {}) { 120 return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); 121 } 122 123 // Returns the HloValue with the given Id. 124 const HloValue& GetValue(HloValue::Id value_id) const; 125 HloValue& GetValue(HloValue::Id value_id); 126 127 // Returns the total number of HloValues. value_count()128 int64 value_count() const { return values_.size(); } 129 130 // Returns a vector of all HloValues stabily sorted by HloValue::Id. values()131 const std::vector<HloValue*>& values() const { return values_vector_; } 132 133 // Returns the call graph used for computing the dataflow. call_graph()134 const CallGraph& call_graph() const { return *call_graph_; } 135 136 string ToString() const; 137 138 // Returns true if 'user' cannot possibly use the buffer at 'index' in 139 // 'operand'. Returns false otherwise. 140 // 141 // 'operand' does not have to be an operand of 'user'. This can be the case 142 // with indirect uses. 143 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 144 const ShapeIndex& index, 145 const HloInstruction* user) const; 146 147 // Returns true if 'user' (at 'user_index') can share a buffer with its 148 // operand 'operand' (at 'operand_index'). Returns false otherwise. 149 // 150 // REQUIRES: 'operand' is an operand of 'user'. 151 bool CanShareOperandBufferWithUser(HloInstruction* operand, 152 const ShapeIndex& operand_index, 153 HloInstruction* user, 154 const ShapeIndex& user_index) const; 155 module()156 const HloModule& module() const { return module_; } 157 158 protected: 159 HloDataflowAnalysis(const HloModule& module, bool ssa_form, 160 bool bitcast_defines_value = false, 161 const CanShareBuffer& can_share_buffer = nullptr); 162 163 // Returns a new HloValue defined at the given instruction and shape index. 164 HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, 165 bool is_phi = false); 166 167 // Marks the HloValue with the given ID for deletion. 168 void MarkValueForDeletion(HloValue::Id value_id); 169 170 // Deletes all HloValues marked for deletion. Should be called after 171 // propagation is complete. 172 void DeleteMarkedValues(); 173 174 // Constructs and initializes the InstructionValueSets of all instructions to 175 // contain exactly the HloValues defined by each instruction. These values can 176 // then propagated throughout the HLO graph by calling Propagate. 177 Status InitializeInstructionValueSets(); 178 179 // Updates the value set of the given instruction based on the values flowing 180 // into the instruction (operands and cross-computation dataflow). 181 bool UpdateInstructionValueSet(HloInstruction* instruction); 182 183 // Updates the value set for a particular instruction type. Returns whether 184 // the instruction value set changed. 185 bool UpdateBitcastValueSet(HloInstruction* bitcast); 186 bool UpdateCallValueSet(HloInstruction* call); 187 bool UpdateConditionalValueSet(HloInstruction* conditional); 188 bool UpdateCopyValueSet(HloInstruction* copy); 189 bool UpdateDomainValueSet(HloInstruction* domain); 190 bool UpdateGetTupleElementValueSet(HloInstruction* gte); 191 bool UpdateParameterValueSet(HloInstruction* parameter); 192 bool UpdateCopyStartValueSet(HloInstruction* copy_start); 193 bool UpdateCopyDoneValueSet(HloInstruction* copy_done); 194 bool UpdateRecvDoneValueSet(HloInstruction* recv_done); 195 bool UpdateTupleSelectValueSet(HloInstruction* select); 196 bool UpdateSendValueSet(HloInstruction* send); 197 bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size); 198 bool UpdateTupleValueSet(HloInstruction* tuple); 199 bool UpdateWhileValueSet(HloInstruction* xla_while); 200 bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); 201 202 // Propagates the dataflow through the module. In particular, it propagates 203 // the HloValueSet from its defining instruction to the users of the 204 // instructions. 205 void Propagate(); 206 207 // Returns the result of the SSA Phi function applied to the given inputs at 208 // the given instruction. 209 bool Phi(HloInstruction* instruction, 210 absl::Span<const InstructionValueSet* const> inputs); 211 212 // Updates the positions of the HloValues in the output of the given 213 // instruction. This should be called after the instruction value set of 214 // 'instruction' has been changed. 'prev_value_set' must point to the previous 215 // state of the value set prior to the change. 'prev_value_set' may be null if 216 // this is the first time positions are being computed. The previous state is 217 // necessary to efficiently remove positions which have been eliminated due to 218 // changes in the instructions' InstructionValueSet. 219 void UpdatePositionsOfValuesAt( 220 HloInstruction* instruction, const InstructionValueSet& new_value_set, 221 const InstructionValueSet* prev_value_set = nullptr); 222 223 // Verifies various invariants of the dataflow analysis. 224 Status Verify() const; 225 226 const HloModule& module_; 227 const bool ssa_form_; 228 const bool bitcast_defines_value_; 229 230 std::unique_ptr<CallGraph> call_graph_; 231 232 // The map of all HloValues in the module. We pass around pointers to the 233 // mapped HloValues, so the underlying container must keep them valid despite 234 // mutations touching other map entries. 235 std::unordered_map<HloValue::Id, HloValue> values_; 236 237 // A map from instruction to InstructionValueSet. 238 std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_; 239 240 // Values marked for deletion during construction. We don't delete them 241 // immediately because references to them may remain in ValueSets temporarily 242 // during propagation. After construction, these values are deleted. 243 std::vector<HloValue::Id> value_ids_to_delete_; 244 245 // A vector containing all HloValues sorted by HloValue::Id. 246 std::vector<HloValue*> values_vector_; 247 248 // The Id to use for the next HloValue. 249 HloValue::Id next_value_id_ = 0; 250 251 // Backend specific function that decides whether an instruction can share 252 // a buffer with its operand. 253 CanShareBuffer can_share_buffer_ = nullptr; 254 }; 255 256 } // namespace xla 257 258 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 259