• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
18 
19 #include <stddef.h>
20 
21 #include <iosfwd>
22 #include <memory>
23 #include <set>
24 #include <string>
25 #include <vector>
26 
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/logical_buffer.h"
34 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
35 #include "tensorflow/compiler/xla/shape_tree.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/gtl/compactptrset.h"
41 
42 namespace xla {
43 
44 // A class describing the source(s) of the Buffer(s) contained in the output of
45 // a particular HLO instruction. The structure of PointsToSet mirrors the
46 // structure of the instruction's shape, which may be an arbitrary tree (eg, a
47 // nested tuple). Each node in this tree corresponds to a single buffer in the
48 // instruction's output and contains the set of Buffers which might define
49 // the corresponding buffer.
50 class PointsToSet {
51  public:
52   // Construct our ShapeTree with a pointer rather than a reference to a Shape
53   // because this is very hot code, and copying (and then destroying) all these
54   // Shapes is slow.
PointsToSet(const Shape * shape)55   explicit PointsToSet(const Shape* shape) : tree_(shape) {}
56 
57   // Returns true if any points-to sets for any subshape element is not a
58   // singleton.
59   bool IsAmbiguous() const;
60 
61   // Returns true if no LogicalBuffer appears in more than one points-to set of
62   // the shape nodes.
63   bool IsDistinct() const;
64 
65   // Returns the total number of different LogicalBuffers contained in this
66   // object. This is equal to CreateFlattenedSet().size().
67   size_t size() const;
68 
69   // Creates a set containing the union of all LogicalBuffers contained in the
70   // PointsToSet.
71   using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
72   BufferSet CreateFlattenedSet() const;
73 
74   // Returns true if the given buffer is in the points-to set at the given
75   // index.
76   bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
77                              const ShapeIndex& index) const;
78 
79   // Returns true if the given buffer is in the points-to set at any index.
80   bool ContainsBuffer(const LogicalBuffer& buffer) const;
81 
82   // Adds the given buffer to the points-to set at the given index. This is a
83   // nop if the buffer already is in the set at that index.
84   void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);
85 
86   // For the subshape at the given index (where index is defined as in
87   // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
88   // which may produce the tuple subshape at that index. For example, given:
89   //
90   // %tuple1 = tuple(...)
91   // %tuple2 = tuple(...)
92   // %select = select(%tuple1, %tuple2)
93   // %nested_tuple = tuple(%select, %tuple1)
94   //
95   // These are the values for tuple_sources() for the PointsToSet of
96   // %nested_tuple:
97   //
98   // tuple_sources({}) = {%nested_tuple}
99   // tuple_sources({0}) = {%tuple1, %tuple2}
100   // tuple_sources({1}) = {%tuple1}
101   //
102   // tuple_sources() at the index of an array shape (not a tuple) returns the
103   // empty set. The instructions in the set returned by tuple_sources
104   // necessarily are either Tuple instructions, constants, or parameters.
105   using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
106   const SourceSet& tuple_sources(const ShapeIndex& index) const;
107 
108   // Add a tuple source instruction for the given index.
109   void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
110 
111   using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;
112 
113   // Return the list of logical buffers for the subshape at index.
element(const ShapeIndex & index)114   const BufferList& element(const ShapeIndex& index) const {
115     return tree_.element(index).buffers;
116   }
mutable_element(const ShapeIndex & index)117   BufferList* mutable_element(const ShapeIndex& index) {
118     return &tree_.mutable_element(index)->buffers;
119   }
120 
121   // Call fn(index, buflist) for every subshape index.
122   template <typename Fn>
ForEachElement(const Fn & fn)123   void ForEachElement(const Fn& fn) const {
124     tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
125       fn(index, elem.buffers);
126     });
127   }
128   template <typename Fn>
ForEachMutableElement(const Fn & fn)129   void ForEachMutableElement(const Fn& fn) {
130     tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
131       fn(index, &elem->buffers);
132     });
133   }
134   template <typename Fn>
ForEachElementWithStatus(const Fn & fn)135   Status ForEachElementWithStatus(const Fn& fn) const {
136     return tree_.ForEachElementWithStatus(
137         [&fn](const ShapeIndex& index, const Elem& elem) {
138           return fn(index, elem.buffers);
139         });
140   }
141 
142  private:
143   struct Elem {
144     BufferList buffers;
145     SourceSet tuple_sources;
146   };
147   ShapeTree<Elem> tree_;
148 
149   // PointsToSet contains references (const LogicalBuffer*) to elements within
150   // TuplePointsToAnalysis, so disable copying.
151   PointsToSet(const PointsToSet&) = delete;
152   PointsToSet& operator=(const PointsToSet&) = delete;
153 };
154 
155 // This class describes a particular subshape in a computation (instruction and
156 // shape index) and the logical buffer which may be a source of the subshape
157 // value.
158 class BufferAlias {
159  public:
BufferAlias(HloInstruction * instruction,const ShapeIndex & index)160   BufferAlias(HloInstruction* instruction, const ShapeIndex& index)
161       : instruction_(instruction), index_(index) {}
162 
163   // Return the instruction/index of the subshape.
instruction()164   HloInstruction* instruction() const { return instruction_; }
index()165   const ShapeIndex& index() const { return index_; }
166 
167   bool operator==(const BufferAlias& other) const {
168     return instruction_ == other.instruction_ && index_ == other.index_;
169   }
170   bool operator!=(const BufferAlias& other) const { return !(*this == other); }
171 
172   std::string ToString() const;
173 
174  private:
175   HloInstruction* instruction_;
176   ShapeIndex index_;
177 };
178 
179 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);
180 
181 // DFS visitor that performs tuple points-to analysis. This analysis determines
182 // the potential sources of each buffer in each instruction's output.
183 class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
184  public:
185   // Runs points-to analysis on 'module'.
186   static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
187       const HloModule* module);
188 
189   // Return the points-to set of an instruction. This describes the potential
190   // sources of each buffer in the instruction's output.
191   const PointsToSet& GetPointsToSet(
192       const HloInstruction* hlo_instruction) const;
193 
194   // Returns the logical buffer with the given ID.
195   const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;
196 
197   // Returns the buffer defined at the given instruction and index. An error is
198   // returned if no buffer is defined at that point.
199   StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
200       const HloInstruction* instruction, const ShapeIndex& index) const;
201 
202   // Return a (possibly empty) vector containing all BufferAliases of the given
203   // logical buffer The buffer alias set is the inverse of the points-to set.
204   // That is, LogicalBuffer B is in the points-to set of instruction I at index
205   // N iff instruction I, index N is a BufferAlias of B.
206   using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
207   const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
208 
209   // Returns the number of logical buffers in the module
num_logical_buffers()210   LogicalBuffer::Id num_logical_buffers() const {
211     return logical_buffer_analysis_->num_logical_buffers();
212   }
213 
214   // Return a the logical buffer with id "id" in the module. Iteration
215   // over all logical buffers is usually done with something like:
216   //
217   // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){
218   //   const auto& buffer = points_to.logical_buffer(id);
219   //   ... do something with buffer ...
220   // }
logical_buffer(LogicalBuffer::Id id)221   LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const {
222     return logical_buffer_analysis_->GetBuffer(id);
223   }
224 
225   // Returns a vector of buffers that the instruction produces. Most
226   // instructions produce a single buffer (the top-level buffer), some produce
227   // no buffers (eg bitcast), and some produce more than one buffer (eg,
228   // tuple-shaped parameters).
229   using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
230   const BufferDefinitionVector& GetBuffersDefinedByInstruction(
231       const HloInstruction* instruction) const;
232 
233   // Returns true if the given instruction defines a buffer at the given index.
234   bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction,
235                                        const ShapeIndex& index) const;
236 
237   // Returns an OK status if the given buffer is defined by instruction
238   // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
239   // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
240   // an FailedPrecondition error status otherwise. An example of a LogicalBuffer
241   // which is not defined is a tuple element in a Tuple instruction. In this
242   // case, the Tuple instruction does not define the LogicalBuffer, rather that
243   // index aliases one of its operands.
244   Status VerifyBuffer(const LogicalBuffer& buffer) const;
245 
246   Status DefaultAction(HloInstruction* hlo_instruction) override;
247   Status HandleTuple(HloInstruction* tuple) override;
248   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
249   Status HandleAsyncStart(HloInstruction* async_start) override;
250   Status HandleAsyncUpdate(HloInstruction* async_update) override;
251   Status HandleAsyncDone(HloInstruction* async_done) override;
252   Status HandleBitcast(HloInstruction* bitcast) override;
253   Status HandleDomain(HloInstruction* domain) override;
254   Status HandleCopy(HloInstruction* copy) override;
255   Status HandleCopyStart(HloInstruction* copy_start) override;
256   Status HandleCopyDone(HloInstruction* copy_done) override;
257   Status HandleRecvDone(HloInstruction* recv_done) override;
258   Status HandleSend(HloInstruction* send) override;
259   Status HandleAddDependency(HloInstruction* add_dependency) override;
260   Status HandleCustomCall(HloInstruction* custom_call) override;
261   Status HandleOptimizationBarrier(HloInstruction* barrier) override;
262 
263   std::string ToString() const;
264 
265   // Returns true if 'user' cannot possibly use the buffer at 'index' in
266   // 'operand'. Returns false otherwise.
267   //
268   // REQUIRES: 'operand' is an operand of 'user'.
269   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
270                                const ShapeIndex& index,
271                                const HloInstruction* user) const;
272 
273  private:
TuplePointsToAnalysis(const HloModule * module,std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)274   explicit TuplePointsToAnalysis(
275       const HloModule* module,
276       std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
277       : module_(module),
278         logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}
279 
280   // Perform the analysis. Should be called immediately after constructing the
281   // object and before calling GetPointsToSet.
282   Status Analyze();
283 
284   // Populates instruction-defined buffers and aliases for each instruction
285   // in 'instructions'.
286   Status PopulateDefinedBuffersAndAliases(
287       const decltype(std::declval<HloComputation>()
288                          .instructions())& instructions);
289 
290   // Creates an empty PointsToSet in the points_to_ map for the given
291   // instruction.
292   PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);
293 
294   // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
295   // copy of the existing PointsToSet for 'src'.
296   PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
297                                        const HloInstruction* src);
298 
299   // Adds the buffers defined by the given instruction to the given vector.
300   Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction,
301                                            BufferDefinitionVector* buffers);
302 
303   // Print points-to set for 'instruction' to 'output'.
304   void InstructionToString(const HloInstruction* instruction,
305                            std::string* output) const;
306 
307   // Information kept per instruction
308   struct PerInstruction {
309     std::unique_ptr<PointsToSet> points_to_set;
310     // Empirically, ~92% of instructions have 1
311     // instruction_defined_buffer, and 99% have 0 or 1
312     BufferDefinitionVector instruction_defined_buffers;
313   };
314 
PerInst(const HloInstruction * inst)315   const PerInstruction* PerInst(const HloInstruction* inst) const {
316     int id = inst->unique_id();
317     DCHECK_GE(id, 0);
318     auto iter = per_instruction_.find(id);
319     if (iter == per_instruction_.end()) {
320       LOG(FATAL) << "Expected per-instruction information to already exist";
321     } else {
322       return iter->second.get();
323     }
324   }
PerInst(const HloInstruction * inst)325   PerInstruction* PerInst(const HloInstruction* inst) {
326     int id = inst->unique_id();
327     DCHECK_GE(id, 0);
328     auto iter = per_instruction_.find(id);
329     if (iter == per_instruction_.end()) {
330       return per_instruction_.emplace(id, std::make_unique<PerInstruction>())
331           .first->second.get();
332     } else {
333       return iter->second.get();
334     }
335   }
336 
337   std::vector<std::pair<HloInstruction*, int64_t>>
338   GetAllUsesOfInstructionAtIndex(HloInstruction* instruction,
339                                  const ShapeIndex& index) const;
340   bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
341                                     const ShapeIndex& operand_index,
342                                     HloInstruction* fusion,
343                                     const int64_t use_operand_index) const;
344 
345   // The module this analysis is performed on.
346   const HloModule* module_;
347 
348   // The logical buffers for this module.
349   const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;
350 
351   // A map from instruction->unique_id() to
352   absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;
353 
354   // A map from LogicalBuffer->id() to alias information about that logical
355   // buffer
356   std::vector<BufferAliasVector> logical_buffer_aliases_;
357 
358   TuplePointsToAnalysis(const TuplePointsToAnalysis&) = delete;
359   TuplePointsToAnalysis& operator=(const TuplePointsToAnalysis&) = delete;
360 };
361 
362 }  // namespace xla
363 
364 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
365