/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_

#include <stddef.h>

#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"

namespace xla {

// A class describing the source(s) of the Buffer(s) contained in the output of
// a particular HLO instruction. The structure of PointsToSet mirrors the
// structure of the instruction's shape, which may be an arbitrary tree (e.g.,
// a nested tuple). Each node in this tree corresponds to a single buffer in
// the instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet {
 public:
  // Construct our ShapeTree with a pointer rather than a reference to a Shape
  // because this is very hot code, and copying (and then destroying) all these
  // Shapes is slow.
  explicit PointsToSet(const Shape* shape) : tree_(shape) {}

  // Returns true if the points-to set of any subshape element is not a
  // singleton.
  bool IsAmbiguous() const;

  // Returns true if no LogicalBuffer appears in more than one points-to set of
  // the shape nodes.
  bool IsDistinct() const;

  // Returns the total number of different LogicalBuffers contained in this
  // object. This is equal to CreateFlattenedSet().size().
  size_t size() const;

  // Creates a set containing the union of all LogicalBuffers contained in the
  // PointsToSet.
  using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
  BufferSet CreateFlattenedSet() const;

  // Returns true if the given buffer is in the points-to set at the given
  // index.
  bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
                             const ShapeIndex& index) const;

  // Returns true if the given buffer is in the points-to set at any index.
  bool ContainsBuffer(const LogicalBuffer& buffer) const;

  // Adds the given buffer to the points-to set at the given index. This is a
  // nop if the buffer already is in the set at that index.
  void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);
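
  // As an illustrative sketch only (not part of this class's interface), a
  // client holding a TuplePointsToAnalysis result might query a PointsToSet
  // like this; 'analysis' and 'instruction' are assumed to exist elsewhere:
  //
  //   const PointsToSet& points_to = analysis->GetPointsToSet(instruction);
  //   if (points_to.IsAmbiguous()) {
  //     // Some subshape has more than one possible defining LogicalBuffer.
  //   }
  //   for (const LogicalBuffer* buffer : points_to.CreateFlattenedSet()) {
  //     // 'buffer' may be a source of some part of the instruction's output.
  //   }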

  // For the subshape at the given index (where index is defined as in
  // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
  // which may produce the tuple subshape at that index. For example, given:
  //
  // %tuple1 = tuple(...)
  // %tuple2 = tuple(...)
  // %select = select(%tuple1, %tuple2)
  // %nested_tuple = tuple(%select, %tuple1)
  //
  // These are the values for tuple_sources() for the PointsToSet of
  // %nested_tuple:
  //
  // tuple_sources({}) = {%nested_tuple}
  // tuple_sources({0}) = {%tuple1, %tuple2}
  // tuple_sources({1}) = {%tuple1}
  //
  // tuple_sources() at the index of an array shape (not a tuple) returns the
  // empty set. The instructions in the set returned by tuple_sources
  // necessarily are either Tuple instructions, constants, or parameters.
  using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
  const SourceSet& tuple_sources(const ShapeIndex& index) const;

  // Add a tuple source instruction for the given index.
  void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);

  using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;

  // Return the list of logical buffers for the subshape at index.
  const BufferList& element(const ShapeIndex& index) const {
    return tree_.element(index).buffers;
  }
  BufferList* mutable_element(const ShapeIndex& index) {
    return &tree_.mutable_element(index)->buffers;
  }

  // Call fn(index, buflist) for every subshape index.
  template <typename Fn>
  void ForEachElement(const Fn& fn) const {
    tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
      fn(index, elem.buffers);
    });
  }
  template <typename Fn>
  void ForEachMutableElement(const Fn& fn) {
    tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
      fn(index, &elem->buffers);
    });
  }
  template <typename Fn>
  Status ForEachElementWithStatus(const Fn& fn) const {
    return tree_.ForEachElementWithStatus(
        [&fn](const ShapeIndex& index, const Elem& elem) {
          return fn(index, elem.buffers);
        });
  }
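
  // As an illustrative sketch only ('points_to' is assumed to be a PointsToSet
  // obtained from a TuplePointsToAnalysis), the per-subshape buffer lists can
  // be walked with ForEachElement:
  //
  //   points_to.ForEachElement(
  //       [](const ShapeIndex& index,
  //          const PointsToSet::BufferList& buffers) {
  //         // 'buffers' holds the LogicalBuffers which may define the value
  //         // at 'index' of the instruction's output.
  //       });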

 private:
  struct Elem {
    BufferList buffers;
    SourceSet tuple_sources;
  };
  ShapeTree<Elem> tree_;

  // PointsToSet contains references (const LogicalBuffer*) to elements within
  // TuplePointsToAnalysis, so disable copying.
  PointsToSet(const PointsToSet&) = delete;
  PointsToSet& operator=(const PointsToSet&) = delete;
};

// This class describes a particular subshape in a computation (instruction and
// shape index) and the logical buffer which may be a source of the subshape
// value.
class BufferAlias {
 public:
  BufferAlias(HloInstruction* instruction, const ShapeIndex& index)
      : instruction_(instruction), index_(index) {}

  // Return the instruction/index of the subshape.
  HloInstruction* instruction() const { return instruction_; }
  const ShapeIndex& index() const { return index_; }

  bool operator==(const BufferAlias& other) const {
    return instruction_ == other.instruction_ && index_ == other.index_;
  }
  bool operator!=(const BufferAlias& other) const { return !(*this == other); }

  std::string ToString() const;

 private:
  HloInstruction* instruction_;
  ShapeIndex index_;
};

std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);

// DFS visitor that performs tuple points-to analysis. This analysis determines
// the potential sources of each buffer in each instruction's output.
class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
 public:
  // Runs points-to analysis on 'module'.
  static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
      const HloModule* module);

  // Returns the points-to set of an instruction. This describes the potential
  // sources of each buffer in the instruction's output.
  const PointsToSet& GetPointsToSet(
      const HloInstruction* hlo_instruction) const;

  // Returns the logical buffer with the given ID.
  const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;

  // Returns the buffer defined at the given instruction and index. An error is
  // returned if no buffer is defined at that point.
  StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
      const HloInstruction* instruction, const ShapeIndex& index) const;

  // Returns a (possibly empty) vector containing all BufferAliases of the
  // given logical buffer. The buffer alias set is the inverse of the points-to
  // set. That is, LogicalBuffer B is in the points-to set of instruction I at
  // index N iff instruction I, index N is a BufferAlias of B.
  using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
  const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;

  // Returns the number of logical buffers in the module.
  LogicalBuffer::Id num_logical_buffers() const {
    return logical_buffer_analysis_->num_logical_buffers();
  }

  // Returns the logical buffer with id "id" in the module. Iteration
  // over all logical buffers is usually done with something like:
  //
  // for (LogicalBuffer::Id id = 0; id < points_to.num_logical_buffers();
  //      id++) {
  //   const auto& buffer = points_to.logical_buffer(id);
  //   ... do something with buffer ...
  // }
  LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const {
    return logical_buffer_analysis_->GetBuffer(id);
  }

  // Returns a vector of buffers that the instruction produces. Most
  // instructions produce a single buffer (the top-level buffer), some produce
  // no buffers (e.g., bitcast), and some produce more than one buffer (e.g.,
  // tuple-shaped parameters).
  using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
  const BufferDefinitionVector& GetBuffersDefinedByInstruction(
      const HloInstruction* instruction) const;
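
  // As an illustrative sketch only (error handling elided; 'module' and
  // 'instruction' are assumed to exist elsewhere):
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> analysis,
  //                       TuplePointsToAnalysis::Run(module));
  //   for (const LogicalBuffer* buffer :
  //        analysis->GetBuffersDefinedByInstruction(instruction)) {
  //     // 'buffer' is defined (not merely aliased) by 'instruction'.
  //   }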

  // Returns true if the given instruction defines a buffer at the given index.
  bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction,
                                       const ShapeIndex& index) const;

  // Returns an OK status if the given buffer is defined by instruction
  // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
  // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
  // a FailedPrecondition error status otherwise. An example of a LogicalBuffer
  // which is not defined is a tuple element in a Tuple instruction. In this
  // case, the Tuple instruction does not define the LogicalBuffer; rather,
  // that index aliases one of its operands.
  Status VerifyBuffer(const LogicalBuffer& buffer) const;

  Status DefaultAction(HloInstruction* hlo_instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleAsyncStart(HloInstruction* async_start) override;
  Status HandleAsyncUpdate(HloInstruction* async_update) override;
  Status HandleAsyncDone(HloInstruction* async_done) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleDomain(HloInstruction* domain) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleCopyStart(HloInstruction* copy_start) override;
  Status HandleCopyDone(HloInstruction* copy_done) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleOptimizationBarrier(HloInstruction* barrier) override;

  std::string ToString() const;

  // Returns true if 'user' cannot possibly use the buffer at 'index' in
  // 'operand'. Returns false otherwise.
  //
  // REQUIRES: 'operand' is an operand of 'user'.
  bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                               const ShapeIndex& index,
                               const HloInstruction* user) const;
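
  // As an illustrative sketch only ('param' and 'gte' are assumed to be
  // instructions where 'gte' is a user of tuple-shaped 'param'):
  //
  //   if (analysis->DoesNotUseOperandBuffer(param, /*index=*/{1}, gte)) {
  //     // 'gte' never reads the buffer at index {1} of 'param'.
  //   }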

 private:
  explicit TuplePointsToAnalysis(
      const HloModule* module,
      std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
      : module_(module),
        logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}

  // Perform the analysis. Should be called immediately after constructing the
  // object and before calling GetPointsToSet.
  Status Analyze();

  // Populates instruction-defined buffers and aliases for each instruction
  // in 'instructions'.
  Status PopulateDefinedBuffersAndAliases(
      const decltype(std::declval<HloComputation>()
                         .instructions())& instructions);

  // Creates an empty PointsToSet in the points_to_ map for the given
  // instruction.
  PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);

  // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
  // copy of the existing PointsToSet for 'src'.
  PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
                                       const HloInstruction* src);

  // Adds the buffers defined by the given instruction to the given vector.
  Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction,
                                           BufferDefinitionVector* buffers);

  // Prints the points-to set for 'instruction' to 'output'.
  void InstructionToString(const HloInstruction* instruction,
                           std::string* output) const;

  // Information kept per instruction.
  struct PerInstruction {
    std::unique_ptr<PointsToSet> points_to_set;
    // Empirically, ~92% of instructions have 1
    // instruction_defined_buffer, and 99% have 0 or 1.
    BufferDefinitionVector instruction_defined_buffers;
  };

  const PerInstruction* PerInst(const HloInstruction* inst) const {
    int id = inst->unique_id();
    DCHECK_GE(id, 0);
    auto iter = per_instruction_.find(id);
    if (iter == per_instruction_.end()) {
      LOG(FATAL) << "Expected per-instruction information to already exist";
    } else {
      return iter->second.get();
    }
  }
  PerInstruction* PerInst(const HloInstruction* inst) {
    int id = inst->unique_id();
    DCHECK_GE(id, 0);
    auto iter = per_instruction_.find(id);
    if (iter == per_instruction_.end()) {
      return per_instruction_.emplace(id, std::make_unique<PerInstruction>())
          .first->second.get();
    } else {
      return iter->second.get();
    }
  }

  std::vector<std::pair<HloInstruction*, int64_t>>
  GetAllUsesOfInstructionAtIndex(HloInstruction* instruction,
                                 const ShapeIndex& index) const;
  bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
                                    const ShapeIndex& operand_index,
                                    HloInstruction* fusion,
                                    const int64_t use_operand_index) const;

  // The module this analysis is performed on.
  const HloModule* module_;

  // The logical buffers for this module.
  const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;

  // A map from instruction->unique_id() to the analysis information kept for
  // that instruction.
  absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;

  // A map from LogicalBuffer->id() to alias information about that logical
  // buffer.
  std::vector<BufferAliasVector> logical_buffer_aliases_;

  TuplePointsToAnalysis(const TuplePointsToAnalysis&) = delete;
  TuplePointsToAnalysis& operator=(const TuplePointsToAnalysis&) = delete;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_