1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ 18 19 #include <stddef.h> 20 #include <iosfwd> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <vector> 25 26 #include "absl/container/flat_hash_map.h" 27 #include "absl/container/inlined_vector.h" 28 #include "absl/types/span.h" 29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 30 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/logical_buffer.h" 33 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h" 34 #include "tensorflow/compiler/xla/shape_tree.h" 35 #include "tensorflow/compiler/xla/statusor.h" 36 #include "tensorflow/compiler/xla/types.h" 37 #include "tensorflow/compiler/xla/xla_data.pb.h" 38 #include "tensorflow/core/lib/core/status.h" 39 #include "tensorflow/core/lib/gtl/compactptrset.h" 40 #include "tensorflow/core/platform/macros.h" 41 #include "tensorflow/core/platform/types.h" 42 43 namespace xla { 44 45 // A class describing the source(s) of the Buffer(s) contained in the output of 46 // a particular HLO instruction. The structure of PointsToSet mirrors the 47 // structure of the instruction's shape, which may be an arbitrary tree (eg, a 48 // nested tuple). Each node in this tree corresponds to a single buffer in the 49 // instruction's output and contains the set of Buffers which might define 50 // the corresponding buffer. 51 class PointsToSet { 52 public: 53 // Construct our ShapeTree with a pointer rather than a reference to a Shape 54 // because this is very hot code, and copying (and then destroying) all these 55 // Shapes is slow. PointsToSet(const Shape * shape)56 explicit PointsToSet(const Shape* shape) : tree_(shape) {} 57 58 // Returns true if any points-to sets for any subshape element is not a 59 // singleton. 60 bool IsAmbiguous() const; 61 62 // Returns true if no LogicalBuffer appears in more than one points-to set of 63 // the shape nodes. 64 bool IsDistinct() const; 65 66 // Returns the total number of different LogicalBuffers contained in this 67 // object. This is equal to CreateFlattenedSet().size(). 68 size_t size() const; 69 70 // Creates a set containing the union of all LogicalBuffers contained in the 71 // PointsToSet. 72 using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>; 73 BufferSet CreateFlattenedSet() const; 74 75 // Returns true if the given buffer is in the points-to set at the given 76 // index. 77 bool ContainsBufferAtIndex(const LogicalBuffer& buffer, 78 const ShapeIndex& index) const; 79 80 // Returns true if the given buffer is in the points-to set at any index. 81 bool ContainsBuffer(const LogicalBuffer& buffer) const; 82 83 // Adds the given buffer to the points-to set at the given index. This is a 84 // nop if the buffer already is in the set at that index. 85 void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index); 86 87 // For the subshape at the given index (where index is defined as in 88 // ShapeUtil::GetSubshape) this method returns the set of HLO instructions 89 // which may produce the tuple subshape at that index. For example, given: 90 // 91 // %tuple1 = tuple(...) 92 // %tuple2 = tuple(...) 93 // %select = select(%tuple1, %tuple2) 94 // %nested_tuple = tuple(%select, %tuple1) 95 // 96 // These are the values for tuple_sources() for the PointsToSet of 97 // %nested_tuple: 98 // 99 // tuple_sources({}) = {%nested_tuple} 100 // tuple_sources({0}) = {%tuple1, %tuple2} 101 // tuple_sources({1}) = {%tuple1} 102 // 103 // tuple_sources() at the index of an array shape (not a tuple) returns the 104 // empty set. The instructions in the set returned by tuple_sources 105 // necessarily are either Tuple instructions, constants, or parameters. 106 using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>; 107 const SourceSet& tuple_sources(const ShapeIndex& index) const; 108 109 // Add a tuple source instruction for the given index. 110 void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); 111 112 using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>; 113 114 // Return the list of logical buffers for the subshape at index. element(const ShapeIndex & index)115 const BufferList& element(const ShapeIndex& index) const { 116 return tree_.element(index).buffers; 117 } mutable_element(const ShapeIndex & index)118 BufferList* mutable_element(const ShapeIndex& index) { 119 return &tree_.mutable_element(index)->buffers; 120 } 121 122 // Call fn(index, buflist) for every subshape index. 123 template <typename Fn> ForEachElement(const Fn & fn)124 void ForEachElement(const Fn& fn) const { 125 tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) { 126 fn(index, elem.buffers); 127 }); 128 } 129 template <typename Fn> ForEachMutableElement(const Fn & fn)130 void ForEachMutableElement(const Fn& fn) { 131 tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) { 132 fn(index, &elem->buffers); 133 }); 134 } 135 template <typename Fn> ForEachElementWithStatus(const Fn & fn)136 Status ForEachElementWithStatus(const Fn& fn) const { 137 return tree_.ForEachElementWithStatus( 138 [&fn](const ShapeIndex& index, const Elem& elem) { 139 return fn(index, elem.buffers); 140 }); 141 } 142 143 private: 144 struct Elem { 145 BufferList buffers; 146 SourceSet tuple_sources; 147 }; 148 ShapeTree<Elem> tree_; 149 150 // PointsToSet contains references (const LogicalBuffer*) to elements within 151 // TuplePointsToAnalysis, so disable copying. 152 TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet); 153 }; 154 155 // This class describes a particular subshape in a computation (instruction and 156 // shape index) and the logical buffer which may be a source of the subshape 157 // value. 158 class BufferAlias { 159 public: BufferAlias(HloInstruction * instruction,const ShapeIndex & index)160 BufferAlias(HloInstruction* instruction, const ShapeIndex& index) 161 : instruction_(instruction), index_(index) {} 162 163 // Return the instruction/index of the subshape. instruction()164 HloInstruction* instruction() const { return instruction_; } index()165 const ShapeIndex& index() const { return index_; } 166 167 bool operator==(const BufferAlias& other) const { 168 return instruction_ == other.instruction_ && index_ == other.index_; 169 } 170 bool operator!=(const BufferAlias& other) const { return !(*this == other); } 171 172 string ToString() const; 173 174 private: 175 HloInstruction* instruction_; 176 ShapeIndex index_; 177 }; 178 179 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); 180 181 // DFS visitor that performs tuple points-to analysis. This analysis determines 182 // the potential sources of each buffer in each instruction's output. 183 class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { 184 public: 185 // Runs points-to analysis on 'module'. 186 static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run( 187 const HloModule* module); 188 189 // Return the points-to set of an instruction. This describes the potential 190 // sources of each buffer in the instruction's output. 191 const PointsToSet& GetPointsToSet( 192 const HloInstruction* hlo_instruction) const; 193 194 // Returns the logical buffer with the given ID. 195 const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; 196 197 // Returns the buffer defined at the given instruction and index. An error is 198 // returned if no buffer is defined at that point. 199 StatusOr<const LogicalBuffer*> GetBufferDefinedAt( 200 const HloInstruction* instruction, const ShapeIndex& index) const; 201 202 // Return a (possibly empty) vector containing all BufferAliases of the given 203 // logical buffer The buffer alias set is the inverse of the points-to set. 204 // That is, LogicalBuffer B is in the points-to set of instruction I at index 205 // N iff instruction I, index N is a BufferAlias of B. 206 using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>; 207 const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; 208 209 // Returns the number of logical buffers in the module num_logical_buffers()210 LogicalBuffer::Id num_logical_buffers() const { 211 return logical_buffer_analysis_->num_logical_buffers(); 212 } 213 214 // Return a the logical buffer with id "id" in the module. Iteration 215 // over all logical buffers is usually done with something like: 216 // 217 // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){ 218 // const auto& buffer = points_to.logical_buffer(id); 219 // ... do something with buffer ... 220 // } logical_buffer(LogicalBuffer::Id id)221 LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const { 222 return logical_buffer_analysis_->GetBuffer(id); 223 } 224 225 // Returns a vector of buffers that the instruction produces. Most 226 // instructions produce a single buffer (the top-level buffer), some produce 227 // no buffers (eg bitcast), and some produce more than one buffer (eg, 228 // tuple-shaped parameters). 229 using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>; 230 const BufferDefinitionVector& GetBuffersDefinedByInstruction( 231 const HloInstruction* instruction) const; 232 233 // Returns true if the given instruction defines a buffer at the given index. 234 bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction, 235 const ShapeIndex& index) const; 236 237 // Returns an OK status if the given buffer is defined by instruction 238 // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer 239 // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns 240 // an FailedPrecondition error status otherwise. An example of a LogicalBuffer 241 // which is not defined is a tuple element in a Tuple instruction. In this 242 // case, the Tuple instruction does not define the LogicalBuffer, rather that 243 // index aliases one of its operands. 244 Status VerifyBuffer(const LogicalBuffer& buffer) const; 245 246 Status DefaultAction(HloInstruction* hlo_instruction) override; 247 Status HandleTuple(HloInstruction* tuple) override; 248 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 249 Status HandleBitcast(HloInstruction* bitcast) override; 250 Status HandleDomain(HloInstruction* domain) override; 251 Status HandleCopy(HloInstruction* copy) override; 252 Status HandleRecvDone(HloInstruction* recv_done) override; 253 Status HandleSend(HloInstruction* send) override; 254 Status HandleTupleSelect(HloInstruction* tuple_select) override; 255 Status HandleAddDependency(HloInstruction* add_dependency) override; 256 257 string ToString() const; 258 259 // Returns true if 'user' cannot possibly use the buffer at 'index' in 260 // 'operand'. Returns false otherwise. 261 // 262 // REQUIRES: 'operand' is an operand of 'user'. 263 bool DoesNotUseOperandBuffer(const HloInstruction* operand, 264 const ShapeIndex& index, 265 const HloInstruction* user) const; 266 267 // Returns true if 'user' (at 'user_index') can share a buffer with its 268 // operand 'operand' (at 'operand_index'). Returns false otherwise. 269 // 270 // REQUIRES: 'operand' is an operand of 'user'. 271 bool CanShareOperandBufferWithUser(HloInstruction* operand, 272 const ShapeIndex& operand_index, 273 HloInstruction* user, 274 const ShapeIndex& user_index) const; 275 276 private: TuplePointsToAnalysis(const HloModule * module,std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)277 explicit TuplePointsToAnalysis( 278 const HloModule* module, 279 std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis) 280 : module_(module), 281 logical_buffer_analysis_(std::move(logical_buffer_analysis)) {} 282 283 // Perform the analysis. Should be called immediately after constructing the 284 // object and before calling GetPointsToSet. 285 Status Analyze(); 286 287 // Populates instruction-defined buffers and aliases for each instruction 288 // in 'instructions'. 289 Status PopulateDefinedBuffersAndAliases(const decltype( 290 std::declval<HloComputation>().instructions())& instructions); 291 292 // Creates an empty PointsToSet in the points_to_ map for the given 293 // instruction. 294 PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction); 295 296 // Creates a PointsToSet in the points_to_ map for 'instruction' which is a 297 // copy of the existing PointsToSet for 'src'. 298 PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction, 299 const HloInstruction* src); 300 301 // Adds the buffers defined by the given instruction to the given vector. 302 Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction, 303 BufferDefinitionVector* buffers); 304 305 // Print points-to set for 'instruction' to 'output'. 306 void InstructionToString(const HloInstruction* instruction, 307 string* output) const; 308 309 // Information kept per instruction 310 struct PerInstruction { 311 std::unique_ptr<PointsToSet> points_to_set; 312 // Empircally, ~92% of instructions have 1 313 // instruction_defined_buffer, and 99% have 0 or 1 314 BufferDefinitionVector instruction_defined_buffers; 315 }; 316 PerInst(const HloInstruction * inst)317 const PerInstruction* PerInst(const HloInstruction* inst) const { 318 int id = inst->unique_id(); 319 DCHECK_GE(id, 0); 320 auto iter = per_instruction_.find(id); 321 if (iter == per_instruction_.end()) { 322 LOG(FATAL) << "Expected per-instruction information to already exist"; 323 } else { 324 return iter->second.get(); 325 } 326 } PerInst(const HloInstruction * inst)327 PerInstruction* PerInst(const HloInstruction* inst) { 328 int id = inst->unique_id(); 329 DCHECK_GE(id, 0); 330 auto iter = per_instruction_.find(id); 331 if (iter == per_instruction_.end()) { 332 return per_instruction_.emplace(id, absl::make_unique<PerInstruction>()) 333 .first->second.get(); 334 } else { 335 return iter->second.get(); 336 } 337 } 338 339 std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( 340 HloInstruction* instruction, const ShapeIndex& index) const; 341 bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, 342 const ShapeIndex& operand_index, 343 HloInstruction* fusion, 344 const int64 use_operand_index) const; 345 346 // The module this analysis is performed on. 347 const HloModule* module_; 348 349 // The logical buffers for this module. 350 const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_; 351 352 // A map from instruction->unique_id() to 353 absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_; 354 355 // A map from LogicalBuffer->id() to alias information about that logical 356 // buffer 357 std::vector<BufferAliasVector> logical_buffer_aliases_; 358 359 TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis); 360 }; 361 362 } // namespace xla 363 364 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ 365