/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
// tracking values across computation boundaries.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_

#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Analysis which identifies all HLO values and their uses in an HLO module.
47 class HloDataflowAnalysis { 48 public: 49 // Infrastructure for passing may-alias hints: HLO passes can populate the 50 // may-alias table. If an empty optional is returned, default rules are used. 51 // 52 // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be 53 // overriden using backend-specific overrides. 54 // 55 // The first parameter of the function should be the instruction, the 56 // second parameter should be an operand of the instruction. The third 57 // parameter should be the output index of the instruction. 58 using CanShareBuffer = std::function<absl::optional<bool>( 59 const HloInstruction* instr, const HloInstruction* operand, 60 const ShapeIndex& user_index)>; 61 62 // Runs dataflow analysis on the given module. Parameters: 63 // 64 // ssa_form : If true then new values are defined at the merge points of 65 // kWhile instructions. Abusing nomenclature somewhat, we call these "phi 66 // values". The merge is formed by the init value and loop backedge. The 67 // SSA form is minimal in that a new phi value is defined only if the 68 // merge point is reachable by multiple different values. The SSA form is 69 // also in loop-closed form in that no values defined inside of a loop 70 // (while body) is used outside of the loop. Example use of this ssa_form 71 // mode is to reason about live range interference of buffers. 72 // 73 // If ssa_form is false, then merge points do not define new 74 // values. Rather, the HloValueSet for the merge point contains the union 75 // of the merged HloValues. 76 // 77 // bitcast_defines_value : If true then the Bitcast HLO instruction defines 78 // a new HLO value in the analysis. If false then Bitcast forwards the 79 // value of its operand. 
80 static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run( 81 const HloModule& module, bool ssa_form = false, 82 bool bitcast_defines_value = false, 83 const CanShareBuffer& can_share_buffer = nullptr); 84 85 static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); 86 87 // Returns true if 'instruction' defines an HLO value at the given shape index 88 // of its output. 89 bool ValueIsDefinedAt(const HloInstruction* instruction, 90 const ShapeIndex& index = {}) const; 91 92 // Returns the HloValue defined by 'instruction' at the given shape index of 93 // its output. 94 // 95 // Precondition: ValueIsDefinedAt is true for this instruction and index. 96 const HloValue& GetValueDefinedAt(const HloInstruction* instruction, 97 const ShapeIndex& index = {}) const; 98 HloValue& GetValueDefinedAt(const HloInstruction* instruction, 99 const ShapeIndex& index = {}); 100 101 // Returns the InstructionValueSet for the given instruction. 102 const InstructionValueSet& GetInstructionValueSet( 103 const HloInstruction* instruction) const; 104 InstructionValueSet& GetInstructionValueSet( 105 const HloInstruction* instruction); 106 107 // Returns all values that are contained in the output of this instruction in 108 // a flattened set. 109 HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; 110 111 // Returns the HloValueSet for the given instruction at the given index or the 112 // given position. 113 const HloValueSet& GetValueSet(const HloInstruction* instruction, 114 const ShapeIndex& index = {}) const; 115 const HloValueSet& GetValueSet(const HloPosition& position) const; 116 HloValueSet& GetValueSet(const HloPosition& position); 117 HloValueSet& GetValueSet(const HloInstruction* instruction, 118 const ShapeIndex& index = {}); 119 120 // Returns the unique value in the HloValueSet at the given instruction and 121 // shape index. CHECKs if the value set does not contain a exactly one value. 
122 const HloValue& GetUniqueValueAt(const HloInstruction* instruction, 123 const ShapeIndex& index = {}) const { 124 return GetValueSet(instruction, index).GetUniqueValue(); 125 } 126 HloValue& GetUniqueValueAt(const HloInstruction* instruction, 127 const ShapeIndex& index = {}) { 128 return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); 129 } 130 131 // Returns the HloValue with the given Id. 132 const HloValue& GetValue(HloValue::Id value_id) const; 133 HloValue& GetValue(HloValue::Id value_id); 134 135 // Returns the total number of HloValues. value_count()136 int64 value_count() const { return values_.size(); } 137 138 // Returns a vector of all HloValues stabily sorted by HloValue::Id. values()139 const std::vector<HloValue*>& values() const { return values_vector_; } 140 141 // Returns the call graph used for computing the dataflow. call_graph()142 const CallGraph& call_graph() const { return *call_graph_; } 143 144 string ToString() const; 145 146 // Returns true if 'user' cannot possibly use the buffer at 'index' in 147 // 'operand'. Returns false otherwise. 148 // 149 // 'operand' does not have to be an operand of 'user'. This can be the 150 // case with indirect uses. 151 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 152 const ShapeIndex& index, 153 const HloInstruction* user) const; 154 155 // Returns true if 'user' (at 'user_index') can share a buffer with its 156 // operand 'operand' (at 'operand_index'). Returns false otherwise. 157 // 158 // REQUIRES: 'operand' is an operand of 'user'. 159 bool CanShareOperandBufferWithUser(HloInstruction* operand, 160 const ShapeIndex& operand_index, 161 HloInstruction* user, 162 const ShapeIndex& user_index) const; 163 module()164 const HloModule& module() const { return module_; } 165 166 // Returns true if the operation is an in-place operation and its operand 0 167 // must alias with the output. 
168 static bool IsInPlaceOperation(HloOpcode opcode); 169 170 // Returns true if the operation is the start/done of an asynchronous 171 // operation, where the buffer used/produced by the op needs to stay alive 172 // until the asynchronous operation completes. 173 static bool IsAsynchronousOperationStart(HloOpcode opcode); 174 static bool IsAsynchronousOperationDone(HloOpcode opcode); 175 176 // Returns a vector consisting of the HloUse (operand number and shape index) 177 // and output shape index of the in-place operations within this HLO. 178 static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs( 179 HloInstruction* instruction); 180 181 protected: 182 HloDataflowAnalysis(const HloModule& module, bool ssa_form, 183 bool bitcast_defines_value = false, 184 const CanShareBuffer& can_share_buffer = nullptr); 185 186 // 1. During value propagation (Propagate function), always create phi 187 // values once it see multiple inputs merging at the same point. It then 188 // records those phi values as well as their inputs in a phi graph. 189 // 190 // 2. Post value propagation, Dataflow analysis can then do certain 191 // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi 192 // nodes. 193 // 194 // Note that this applies in SSA form, and Both of the functions are 195 // guaranteed to exit. 196 // 197 void OptimizePhiValues(); 198 199 // Returns a new HloValue defined at the given instruction and shape index. 200 HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, 201 bool is_phi); 202 203 // Marks the HloValue with the given ID for deletion. 204 void MarkValueForDeletion(HloValue::Id value_id); 205 206 // Deletes all HloValues marked for deletion. Should be called after 207 // propagation is complete. 208 void DeleteMarkedValues(); 209 210 // Constructs and initializes the InstructionValueSets of all instructions to 211 // contain exactly the HloValues defined by each instruction. 
These values can 212 // then propagated throughout the HLO graph by calling Propagate. 213 Status InitializeInstructionValueSets(); 214 215 // Updates the value set of the given instruction based on the values flowing 216 // into the instruction (operands and cross-computation dataflow). 217 bool UpdateInstructionValueSet(HloInstruction* instruction); 218 219 // Updates the value set for a particular instruction type. Returns whether 220 // the instruction value set changed. 221 bool UpdateBitcastValueSet(HloInstruction* bitcast); 222 bool UpdateCallValueSet(HloInstruction* call); 223 bool UpdateConditionalValueSet(HloInstruction* conditional); 224 bool UpdateCopyValueSet(HloInstruction* copy); 225 bool UpdateCustomCallValueSet(HloInstruction* custom_call); 226 bool UpdateDomainValueSet(HloInstruction* domain); 227 bool UpdateGetTupleElementValueSet(HloInstruction* gte); 228 bool UpdateParameterValueSet(HloInstruction* parameter); 229 bool UpdateCopyStartValueSet(HloInstruction* copy_start); 230 bool UpdateCopyDoneValueSet(HloInstruction* copy_done); 231 bool UpdateRecvDoneValueSet(HloInstruction* recv_done); 232 bool UpdateTupleSelectValueSet(HloInstruction* select); 233 bool UpdateSendValueSet(HloInstruction* send); 234 bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size); 235 bool UpdateTupleValueSet(HloInstruction* tuple); 236 bool UpdateWhileValueSet(HloInstruction* xla_while); 237 bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); 238 bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start); 239 bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done); 240 bool UpdateAllReduceStartValueSet(HloInstruction* all_reduce_start); 241 bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done); 242 bool UpdateCollectivePermuteStartValueSet( 243 HloInstruction* collective_permute_start); 244 bool UpdateCollectivePermuteDoneValueSet( 245 HloInstruction* collective_permute_done); 246 247 // Propagates the 
dataflow through the module. In particular, it propagates 248 // the HloValueSet from its defining instruction to the users of the 249 // instructions. 250 void Propagate(); 251 252 // Returns the result of the SSA Phi function applied to the given inputs at 253 // the given instruction. 254 bool Phi(HloInstruction* instruction, 255 absl::Span<const InstructionValueSet* const> inputs); 256 257 // Updates the positions of the HloValues in the output of the given 258 // instruction. This should be called after the instruction value set of 259 // 'instruction' has been changed. 'prev_value_set' must point to the previous 260 // state of the value set prior to the change. 'prev_value_set' may be null if 261 // this is the first time positions are being computed. The previous state is 262 // necessary to efficiently remove positions which have been eliminated due to 263 // changes in the instructions' InstructionValueSet. 264 void UpdatePositionsOfValuesAt( 265 HloInstruction* instruction, const InstructionValueSet& new_value_set, 266 const InstructionValueSet* prev_value_set = nullptr); 267 268 // Verifies various invariants of the dataflow analysis. 269 Status Verify() const; 270 271 const HloModule& module_; 272 const bool ssa_form_; 273 const bool bitcast_defines_value_; 274 275 std::unique_ptr<CallGraph> call_graph_; 276 277 // The map of all HloValues in the module. We pass around pointers to the 278 // mapped HloValues, so the underlying container must keep them valid despite 279 // mutations touching other map entries. 280 std::unordered_map<HloValue::Id, HloValue> values_; 281 282 // A map from instruction to InstructionValueSet. 283 std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_; 284 285 // Values marked for deletion during construction. We don't delete them 286 // immediately because references to them may remain in ValueSets temporarily 287 // during propagation. After construction, these values are deleted. 
288 std::vector<HloValue::Id> value_ids_to_delete_; 289 290 // A vector containing all HloValues sorted by HloValue::Id. 291 std::vector<HloValue*> values_vector_; 292 293 // The Id to use for the next HloValue. 294 HloValue::Id next_value_id_ = 0; 295 296 // An explicit graph holding phi values and edges. 297 PhiGraph phi_graph_; 298 299 // Backend specific function that decides whether an instruction can share 300 // a buffer with its operand. 301 CanShareBuffer can_share_buffer_ = nullptr; 302 }; 303 304 } // namespace xla 305 306 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ 307