/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_

#include <stddef.h>

#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// A class describing the source(s) of the Buffer(s) contained in the output of
// a particular HLO instruction. The structure of PointsToSet mirrors the
// structure of the instruction's shape, which may be an arbitrary tree (eg, a
// nested tuple). Each node in this tree corresponds to a single buffer in the
// instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet {
 public:
  // Construct our ShapeTree with a pointer rather than a reference to a Shape
  // because this is very hot code, and copying (and then destroying) all these
  // Shapes is slow.
  explicit PointsToSet(const Shape* shape) : tree_(shape) {}

  // Returns true if the points-to set of any subshape element is not a
  // singleton.
  bool IsAmbiguous() const;

  // Returns true if no LogicalBuffer appears in more than one points-to set of
  // the shape nodes.
  bool IsDistinct() const;

  // Returns the total number of different LogicalBuffers contained in this
  // object. This is equal to CreateFlattenedSet().size().
  size_t size() const;

  // Creates a set containing the union of all LogicalBuffers contained in the
  // PointsToSet.
  using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
  BufferSet CreateFlattenedSet() const;

  // Returns true if the given buffer is in the points-to set at the given
  // index.
  bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
                             const ShapeIndex& index) const;

  // Returns true if the given buffer is in the points-to set at any index.
  bool ContainsBuffer(const LogicalBuffer& buffer) const;

  // Adds the given buffer to the points-to set at the given index. This is a
  // nop if the buffer already is in the set at that index.
  void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);

  // For the subshape at the given index (where index is defined as in
  // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
  // which may produce the tuple subshape at that index. For example, given:
  //
  // %tuple1 = tuple(...)
  // %tuple2 = tuple(...)
  // %select = select(%tuple1, %tuple2)
  // %nested_tuple = tuple(%select, %tuple1)
  //
  // These are the values for tuple_sources() for the PointsToSet of
  // %nested_tuple:
  //
  // tuple_sources({}) = {%nested_tuple}
  // tuple_sources({0}) = {%tuple1, %tuple2}
  // tuple_sources({1}) = {%tuple1}
  //
  // tuple_sources() at the index of an array shape (not a tuple) returns the
  // empty set. The instructions in the set returned by tuple_sources are
  // necessarily either Tuple instructions, constants, or parameters.
  using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
  const SourceSet& tuple_sources(const ShapeIndex& index) const;

  // Adds a tuple source instruction for the given index.
  void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);

  using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;

  // Returns the list of logical buffers for the subshape at index.
  const BufferList& element(const ShapeIndex& index) const {
    return tree_.element(index).buffers;
  }
  BufferList* mutable_element(const ShapeIndex& index) {
    return &tree_.mutable_element(index)->buffers;
  }

  // Calls fn(index, buflist) for every subshape index.
  template <typename Fn>
  void ForEachElement(const Fn& fn) const {
    tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
      fn(index, elem.buffers);
    });
  }
  template <typename Fn>
  void ForEachMutableElement(const Fn& fn) {
    tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
      fn(index, &elem->buffers);
    });
  }
  template <typename Fn>
  Status ForEachElementWithStatus(const Fn& fn) const {
    return tree_.ForEachElementWithStatus(
        [&fn](const ShapeIndex& index, const Elem& elem) {
          return fn(index, elem.buffers);
        });
  }

 private:
  struct Elem {
    BufferList buffers;
    SourceSet tuple_sources;
  };
  ShapeTree<Elem> tree_;

  // PointsToSet contains references (const LogicalBuffer*) to elements within
  // TuplePointsToAnalysis, so disable copying.
  TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet);
};
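
// Example use of PointsToSet (a minimal, illustrative sketch; the HLO names
// %a, %b, %t and the variable `points_to` are hypothetical):
//
//   // Given HLO of the form:
//   //   %a = ...               (array-shaped)
//   //   %b = ...               (array-shaped)
//   //   %t = tuple(%a, %b)
//   // and `points_to` being the PointsToSet of %t:
//   points_to.element({});   // buffer(s) that may define the top-level tuple
//   points_to.element({0});  // buffer(s) that may define element 0 (from %a)
//   points_to.element({1});  // buffer(s) that may define element 1 (from %b)
//   points_to.size() == points_to.CreateFlattenedSet().size();  // always true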

// This class describes a particular subshape in a computation (instruction and
// shape index) and the logical buffer which may be a source of the subshape
// value.
class BufferAlias {
 public:
  BufferAlias(HloInstruction* instruction, const ShapeIndex& index)
      : instruction_(instruction), index_(index) {}

  // Return the instruction/index of the subshape.
  HloInstruction* instruction() const { return instruction_; }
  const ShapeIndex& index() const { return index_; }

  bool operator==(const BufferAlias& other) const {
    return instruction_ == other.instruction_ && index_ == other.index_;
  }
  bool operator!=(const BufferAlias& other) const { return !(*this == other); }

  string ToString() const;

 private:
  HloInstruction* instruction_;
  ShapeIndex index_;
};

std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);

// DFS visitor that performs tuple points-to analysis. This analysis determines
// the potential sources of each buffer in each instruction's output.
class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
 public:
  // Runs points-to analysis on 'module'.
  static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
      const HloModule* module);

  // Returns the points-to set of an instruction. This describes the potential
  // sources of each buffer in the instruction's output.
  const PointsToSet& GetPointsToSet(
      const HloInstruction* hlo_instruction) const;

  // Returns the logical buffer with the given ID.
  const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;

  // Returns the buffer defined at the given instruction and index. An error is
  // returned if no buffer is defined at that point.
  StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
      const HloInstruction* instruction, const ShapeIndex& index) const;

  // Returns a (possibly empty) vector containing all BufferAliases of the
  // given logical buffer. The buffer alias set is the inverse of the points-to
  // set. That is, LogicalBuffer B is in the points-to set of instruction I at
  // index N iff (instruction I, index N) is a BufferAlias of B.
  using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
  const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;

  // Returns the number of logical buffers in the module.
  LogicalBuffer::Id num_logical_buffers() const {
    return logical_buffer_analysis_->num_logical_buffers();
  }

  // Returns the logical buffer with id "id" in the module. Iteration
  // over all logical buffers is usually done with something like:
  //
  //   for (LogicalBuffer::Id id = 0; id < points_to.num_logical_buffers();
  //        id++) {
  //     const auto& buffer = points_to.logical_buffer(id);
  //     ... do something with buffer ...
  //   }
  LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const {
    return logical_buffer_analysis_->GetBuffer(id);
  }

  // Returns a vector of buffers that the instruction produces. Most
  // instructions produce a single buffer (the top-level buffer), some produce
  // no buffers (eg bitcast), and some produce more than one buffer (eg,
  // tuple-shaped parameters).
  using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
  const BufferDefinitionVector& GetBuffersDefinedByInstruction(
      const HloInstruction* instruction) const;
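
  // Example of running and querying the analysis (a minimal, illustrative
  // sketch; `module` and `instruction` are hypothetical, and error handling
  // is elided):
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> analysis,
  //                       TuplePointsToAnalysis::Run(module));
  //   const PointsToSet& points_to = analysis->GetPointsToSet(instruction);
  //   const auto& defined_buffers =
  //       analysis->GetBuffersDefinedByInstruction(instruction);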

  // Returns true if the given instruction defines a buffer at the given index.
  bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction,
                                       const ShapeIndex& index) const;

  // Returns an OK status if the given buffer is defined by instruction
  // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
  // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
  // a FailedPrecondition error status otherwise. An example of a LogicalBuffer
  // which is not defined is a tuple element in a Tuple instruction. In this
  // case, the Tuple instruction does not define the LogicalBuffer; rather,
  // that index aliases one of its operands.
  Status VerifyBuffer(const LogicalBuffer& buffer) const;

  Status DefaultAction(HloInstruction* hlo_instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleDomain(HloInstruction* domain) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleCopyStart(HloInstruction* copy_start) override;
  Status HandleCopyDone(HloInstruction* copy_done) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;

  string ToString() const;

  // Returns true if 'user' cannot possibly use the buffer at 'index' in
  // 'operand'. Returns false otherwise.
  //
  // REQUIRES: 'operand' is an operand of 'user'.
  bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                               const ShapeIndex& index,
                               const HloInstruction* user) const;

 private:
  explicit TuplePointsToAnalysis(
      const HloModule* module,
      std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
      : module_(module),
        logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}

  // Perform the analysis. Should be called immediately after constructing the
  // object and before calling GetPointsToSet.
  Status Analyze();

  // Populates instruction-defined buffers and aliases for each instruction
  // in 'instructions'.
  Status PopulateDefinedBuffersAndAliases(const decltype(
      std::declval<HloComputation>().instructions())& instructions);

  // Creates an empty PointsToSet in the points_to_ map for the given
  // instruction.
  PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);

  // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
  // copy of the existing PointsToSet for 'src'.
  PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
                                       const HloInstruction* src);

  // Adds the buffers defined by the given instruction to the given vector.
  Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction,
                                           BufferDefinitionVector* buffers);

  // Print points-to set for 'instruction' to 'output'.
  void InstructionToString(const HloInstruction* instruction,
                           string* output) const;

  // Information kept per instruction.
  struct PerInstruction {
    std::unique_ptr<PointsToSet> points_to_set;
    // Empirically, ~92% of instructions have 1 instruction_defined_buffer,
    // and 99% have 0 or 1.
    BufferDefinitionVector instruction_defined_buffers;
  };

  const PerInstruction* PerInst(const HloInstruction* inst) const {
    int id = inst->unique_id();
    DCHECK_GE(id, 0);
    auto iter = per_instruction_.find(id);
    if (iter == per_instruction_.end()) {
      LOG(FATAL) << "Expected per-instruction information to already exist";
    } else {
      return iter->second.get();
    }
  }
  PerInstruction* PerInst(const HloInstruction* inst) {
    int id = inst->unique_id();
    DCHECK_GE(id, 0);
    auto iter = per_instruction_.find(id);
    if (iter == per_instruction_.end()) {
      return per_instruction_.emplace(id, absl::make_unique<PerInstruction>())
          .first->second.get();
    } else {
      return iter->second.get();
    }
  }

  std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
      HloInstruction* instruction, const ShapeIndex& index) const;
  bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
                                    const ShapeIndex& operand_index,
                                    HloInstruction* fusion,
                                    const int64 use_operand_index) const;

  // The module this analysis is performed on.
  const HloModule* module_;

  // The logical buffers for this module.
  const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;

  // A map from instruction->unique_id() to the per-instruction analysis
  // information for that instruction.
  absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;

  // A map from LogicalBuffer->id() to alias information about that logical
  // buffer.
  std::vector<BufferAliasVector> logical_buffer_aliases_;

  TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_