• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
18 
19 #include <stddef.h>
20 
21 #include <iosfwd>
22 #include <memory>
23 #include <set>
24 #include <string>
25 #include <vector>
26 
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/logical_buffer.h"
34 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
35 #include "tensorflow/compiler/xla/shape_tree.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/gtl/compactptrset.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace xla {
45 
46 // A class describing the source(s) of the Buffer(s) contained in the output of
47 // a particular HLO instruction. The structure of PointsToSet mirrors the
48 // structure of the instruction's shape, which may be an arbitrary tree (eg, a
49 // nested tuple). Each node in this tree corresponds to a single buffer in the
50 // instruction's output and contains the set of Buffers which might define
51 // the corresponding buffer.
52 class PointsToSet {
53  public:
54   // Construct our ShapeTree with a pointer rather than a reference to a Shape
55   // because this is very hot code, and copying (and then destroying) all these
56   // Shapes is slow.
PointsToSet(const Shape * shape)57   explicit PointsToSet(const Shape* shape) : tree_(shape) {}
58 
59   // Returns true if any points-to sets for any subshape element is not a
60   // singleton.
61   bool IsAmbiguous() const;
62 
63   // Returns true if no LogicalBuffer appears in more than one points-to set of
64   // the shape nodes.
65   bool IsDistinct() const;
66 
67   // Returns the total number of different LogicalBuffers contained in this
68   // object. This is equal to CreateFlattenedSet().size().
69   size_t size() const;
70 
71   // Creates a set containing the union of all LogicalBuffers contained in the
72   // PointsToSet.
73   using BufferSet = tensorflow::gtl::CompactPointerSet<const LogicalBuffer*>;
74   BufferSet CreateFlattenedSet() const;
75 
76   // Returns true if the given buffer is in the points-to set at the given
77   // index.
78   bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
79                              const ShapeIndex& index) const;
80 
81   // Returns true if the given buffer is in the points-to set at any index.
82   bool ContainsBuffer(const LogicalBuffer& buffer) const;
83 
84   // Adds the given buffer to the points-to set at the given index. This is a
85   // nop if the buffer already is in the set at that index.
86   void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);
87 
88   // For the subshape at the given index (where index is defined as in
89   // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
90   // which may produce the tuple subshape at that index. For example, given:
91   //
92   // %tuple1 = tuple(...)
93   // %tuple2 = tuple(...)
94   // %select = select(%tuple1, %tuple2)
95   // %nested_tuple = tuple(%select, %tuple1)
96   //
97   // These are the values for tuple_sources() for the PointsToSet of
98   // %nested_tuple:
99   //
100   // tuple_sources({}) = {%nested_tuple}
101   // tuple_sources({0}) = {%tuple1, %tuple2}
102   // tuple_sources({1}) = {%tuple1}
103   //
104   // tuple_sources() at the index of an array shape (not a tuple) returns the
105   // empty set. The instructions in the set returned by tuple_sources
106   // necessarily are either Tuple instructions, constants, or parameters.
107   using SourceSet = tensorflow::gtl::CompactPointerSet<HloInstruction*>;
108   const SourceSet& tuple_sources(const ShapeIndex& index) const;
109 
110   // Add a tuple source instruction for the given index.
111   void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
112 
113   using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;
114 
115   // Return the list of logical buffers for the subshape at index.
element(const ShapeIndex & index)116   const BufferList& element(const ShapeIndex& index) const {
117     return tree_.element(index).buffers;
118   }
mutable_element(const ShapeIndex & index)119   BufferList* mutable_element(const ShapeIndex& index) {
120     return &tree_.mutable_element(index)->buffers;
121   }
122 
123   // Call fn(index, buflist) for every subshape index.
124   template <typename Fn>
ForEachElement(const Fn & fn)125   void ForEachElement(const Fn& fn) const {
126     tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) {
127       fn(index, elem.buffers);
128     });
129   }
130   template <typename Fn>
ForEachMutableElement(const Fn & fn)131   void ForEachMutableElement(const Fn& fn) {
132     tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) {
133       fn(index, &elem->buffers);
134     });
135   }
136   template <typename Fn>
ForEachElementWithStatus(const Fn & fn)137   Status ForEachElementWithStatus(const Fn& fn) const {
138     return tree_.ForEachElementWithStatus(
139         [&fn](const ShapeIndex& index, const Elem& elem) {
140           return fn(index, elem.buffers);
141         });
142   }
143 
144  private:
145   struct Elem {
146     BufferList buffers;
147     SourceSet tuple_sources;
148   };
149   ShapeTree<Elem> tree_;
150 
151   // PointsToSet contains references (const LogicalBuffer*) to elements within
152   // TuplePointsToAnalysis, so disable copying.
153   TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet);
154 };
155 
156 // This class describes a particular subshape in a computation (instruction and
157 // shape index) and the logical buffer which may be a source of the subshape
158 // value.
159 class BufferAlias {
160  public:
BufferAlias(HloInstruction * instruction,const ShapeIndex & index)161   BufferAlias(HloInstruction* instruction, const ShapeIndex& index)
162       : instruction_(instruction), index_(index) {}
163 
164   // Return the instruction/index of the subshape.
instruction()165   HloInstruction* instruction() const { return instruction_; }
index()166   const ShapeIndex& index() const { return index_; }
167 
168   bool operator==(const BufferAlias& other) const {
169     return instruction_ == other.instruction_ && index_ == other.index_;
170   }
171   bool operator!=(const BufferAlias& other) const { return !(*this == other); }
172 
173   string ToString() const;
174 
175  private:
176   HloInstruction* instruction_;
177   ShapeIndex index_;
178 };
179 
180 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);
181 
182 // DFS visitor that performs tuple points-to analysis. This analysis determines
183 // the potential sources of each buffer in each instruction's output.
184 class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
185  public:
186   // Runs points-to analysis on 'module'.
187   static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
188       const HloModule* module);
189 
190   // Return the points-to set of an instruction. This describes the potential
191   // sources of each buffer in the instruction's output.
192   const PointsToSet& GetPointsToSet(
193       const HloInstruction* hlo_instruction) const;
194 
195   // Returns the logical buffer with the given ID.
196   const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;
197 
198   // Returns the buffer defined at the given instruction and index. An error is
199   // returned if no buffer is defined at that point.
200   StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
201       const HloInstruction* instruction, const ShapeIndex& index) const;
202 
203   // Return a (possibly empty) vector containing all BufferAliases of the given
204   // logical buffer The buffer alias set is the inverse of the points-to set.
205   // That is, LogicalBuffer B is in the points-to set of instruction I at index
206   // N iff instruction I, index N is a BufferAlias of B.
207   using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
208   const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
209 
210   // Returns the number of logical buffers in the module
num_logical_buffers()211   LogicalBuffer::Id num_logical_buffers() const {
212     return logical_buffer_analysis_->num_logical_buffers();
213   }
214 
215   // Return a the logical buffer with id "id" in the module. Iteration
216   // over all logical buffers is usually done with something like:
217   //
218   // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){
219   //   const auto& buffer = points_to.logical_buffer(id);
220   //   ... do something with buffer ...
221   // }
logical_buffer(LogicalBuffer::Id id)222   LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const {
223     return logical_buffer_analysis_->GetBuffer(id);
224   }
225 
226   // Returns a vector of buffers that the instruction produces. Most
227   // instructions produce a single buffer (the top-level buffer), some produce
228   // no buffers (eg bitcast), and some produce more than one buffer (eg,
229   // tuple-shaped parameters).
230   using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
231   const BufferDefinitionVector& GetBuffersDefinedByInstruction(
232       const HloInstruction* instruction) const;
233 
234   // Returns true if the given instruction defines a buffer at the given index.
235   bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction,
236                                        const ShapeIndex& index) const;
237 
238   // Returns an OK status if the given buffer is defined by instruction
239   // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
240   // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
241   // an FailedPrecondition error status otherwise. An example of a LogicalBuffer
242   // which is not defined is a tuple element in a Tuple instruction. In this
243   // case, the Tuple instruction does not define the LogicalBuffer, rather that
244   // index aliases one of its operands.
245   Status VerifyBuffer(const LogicalBuffer& buffer) const;
246 
247   Status DefaultAction(HloInstruction* hlo_instruction) override;
248   Status HandleTuple(HloInstruction* tuple) override;
249   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
250   Status HandleBitcast(HloInstruction* bitcast) override;
251   Status HandleDomain(HloInstruction* domain) override;
252   Status HandleCopy(HloInstruction* copy) override;
253   Status HandleCopyStart(HloInstruction* copy_start) override;
254   Status HandleCopyDone(HloInstruction* copy_done) override;
255   Status HandleRecvDone(HloInstruction* recv_done) override;
256   Status HandleSend(HloInstruction* send) override;
257   Status HandleTupleSelect(HloInstruction* tuple_select) override;
258   Status HandleAddDependency(HloInstruction* add_dependency) override;
259   Status HandleCustomCall(HloInstruction* custom_call) override;
260 
261   string ToString() const;
262 
263   // Returns true if 'user' cannot possibly use the buffer at 'index' in
264   // 'operand'. Returns false otherwise.
265   //
266   // REQUIRES: 'operand' is an operand of 'user'.
267   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
268                                const ShapeIndex& index,
269                                const HloInstruction* user) const;
270 
271  private:
TuplePointsToAnalysis(const HloModule * module,std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)272   explicit TuplePointsToAnalysis(
273       const HloModule* module,
274       std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis)
275       : module_(module),
276         logical_buffer_analysis_(std::move(logical_buffer_analysis)) {}
277 
278   // Perform the analysis. Should be called immediately after constructing the
279   // object and before calling GetPointsToSet.
280   Status Analyze();
281 
282   // Populates instruction-defined buffers and aliases for each instruction
283   // in 'instructions'.
284   Status PopulateDefinedBuffersAndAliases(const decltype(
285       std::declval<HloComputation>().instructions())& instructions);
286 
287   // Creates an empty PointsToSet in the points_to_ map for the given
288   // instruction.
289   PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);
290 
291   // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
292   // copy of the existing PointsToSet for 'src'.
293   PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
294                                        const HloInstruction* src);
295 
296   // Adds the buffers defined by the given instruction to the given vector.
297   Status GatherBuffersDefinedByInstruction(const HloInstruction* instruction,
298                                            BufferDefinitionVector* buffers);
299 
300   // Print points-to set for 'instruction' to 'output'.
301   void InstructionToString(const HloInstruction* instruction,
302                            string* output) const;
303 
304   // Information kept per instruction
305   struct PerInstruction {
306     std::unique_ptr<PointsToSet> points_to_set;
307     // Empirically, ~92% of instructions have 1
308     // instruction_defined_buffer, and 99% have 0 or 1
309     BufferDefinitionVector instruction_defined_buffers;
310   };
311 
PerInst(const HloInstruction * inst)312   const PerInstruction* PerInst(const HloInstruction* inst) const {
313     int id = inst->unique_id();
314     DCHECK_GE(id, 0);
315     auto iter = per_instruction_.find(id);
316     if (iter == per_instruction_.end()) {
317       LOG(FATAL) << "Expected per-instruction information to already exist";
318     } else {
319       return iter->second.get();
320     }
321   }
PerInst(const HloInstruction * inst)322   PerInstruction* PerInst(const HloInstruction* inst) {
323     int id = inst->unique_id();
324     DCHECK_GE(id, 0);
325     auto iter = per_instruction_.find(id);
326     if (iter == per_instruction_.end()) {
327       return per_instruction_.emplace(id, absl::make_unique<PerInstruction>())
328           .first->second.get();
329     } else {
330       return iter->second.get();
331     }
332   }
333 
334   std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
335       HloInstruction* instruction, const ShapeIndex& index) const;
336   bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
337                                     const ShapeIndex& operand_index,
338                                     HloInstruction* fusion,
339                                     const int64 use_operand_index) const;
340 
341   // The module this analysis is performed on.
342   const HloModule* module_;
343 
344   // The logical buffers for this module.
345   const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;
346 
347   // A map from instruction->unique_id() to
348   absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;
349 
350   // A map from LogicalBuffer->id() to alias information about that logical
351   // buffer
352   std::vector<BufferAliasVector> logical_buffer_aliases_;
353 
354   TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis);
355 };
356 
357 }  // namespace xla
358 
359 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
360