• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Analysis for determining the possible set of values for all positions
17 // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
18 // tracking values across computation boundaries.
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
22 
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
43 
44 namespace xla {
45 
46 // Analysis which identifies all HLO values and their uses in an HLO module.
47 class HloDataflowAnalysis {
48  public:
49   // Infrastructure for passing may-alias hints: HLO passes can populate the
50   // may-alias table. If an empty optional is returned, default rules are used.
51   //
52   // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
53   // overriden using backend-specific overrides.
54   //
55   // The first parameter of the function should be the instruction, the
56   // second parameter should be an operand of the instruction. The third
57   // parameter should be the output index of the instruction.
58   using CanShareBuffer = std::function<absl::optional<bool>(
59       const HloInstruction* instr, const HloInstruction* operand,
60       const ShapeIndex& user_index)>;
61 
62   // Runs dataflow analysis on the given module. Parameters:
63   //
64   //   ssa_form : If true then new values are defined at the merge points of
65   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
66   //     values".  The merge is formed by the init value and loop backedge. The
67   //     SSA form is minimal in that a new phi value is defined only if the
68   //     merge point is reachable by multiple different values. The SSA form is
69   //     also in loop-closed form in that no values defined inside of a loop
70   //     (while body) is used outside of the loop. Example use of this ssa_form
71   //     mode is to reason about live range interference of buffers.
72   //
73   //     If ssa_form is false, then merge points do not define new
74   //     values. Rather, the HloValueSet for the merge point contains the union
75   //     of the merged HloValues.
76   //
77   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
78   //     a new HLO value in the analysis. If false then Bitcast forwards the
79   //     value of its operand.
80   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
81       const HloModule& module, bool ssa_form = false,
82       bool bitcast_defines_value = false,
83       const CanShareBuffer& can_share_buffer = nullptr);
84 
85   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
86 
87   // Returns true if 'instruction' defines an HLO value at the given shape index
88   // of its output.
89   bool ValueIsDefinedAt(const HloInstruction* instruction,
90                         const ShapeIndex& index = {}) const;
91 
92   // Returns the HloValue defined by 'instruction' at the given shape index of
93   // its output.
94   //
95   // Precondition: ValueIsDefinedAt is true for this instruction and index.
96   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
97                                     const ShapeIndex& index = {}) const;
98   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
99                               const ShapeIndex& index = {});
100 
101   // Returns the InstructionValueSet for the given instruction.
102   const InstructionValueSet& GetInstructionValueSet(
103       const HloInstruction* instruction) const;
104   InstructionValueSet& GetInstructionValueSet(
105       const HloInstruction* instruction);
106 
107   // Returns all values that are contained in the output of this instruction in
108   // a flattened set.
109   HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;
110 
111   // Returns the HloValueSet for the given instruction at the given index or the
112   // given position.
113   const HloValueSet& GetValueSet(const HloInstruction* instruction,
114                                  const ShapeIndex& index = {}) const;
115   const HloValueSet& GetValueSet(const HloPosition& position) const;
116   HloValueSet& GetValueSet(const HloPosition& position);
117   HloValueSet& GetValueSet(const HloInstruction* instruction,
118                            const ShapeIndex& index = {});
119 
120   // Returns the unique value in the HloValueSet at the given instruction and
121   // shape index. CHECKs if the value set does not contain a exactly one value.
122   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
123                                    const ShapeIndex& index = {}) const {
124     return GetValueSet(instruction, index).GetUniqueValue();
125   }
126   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
127                              const ShapeIndex& index = {}) {
128     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
129   }
130 
131   // Returns the HloValue with the given Id.
132   const HloValue& GetValue(HloValue::Id value_id) const;
133   HloValue& GetValue(HloValue::Id value_id);
134 
135   // Returns the total number of HloValues.
value_count()136   int64 value_count() const { return values_.size(); }
137 
138   // Returns a vector of all HloValues stabily sorted by HloValue::Id.
values()139   const std::vector<HloValue*>& values() const { return values_vector_; }
140 
141   // Returns the call graph used for computing the dataflow.
call_graph()142   const CallGraph& call_graph() const { return *call_graph_; }
143 
144   string ToString() const;
145 
146   // Returns true if 'user' cannot possibly use the buffer at 'index' in
147   // 'operand'. Returns false otherwise.
148   //
149   // 'operand' does not have to be an operand of 'user'. This can be the
150   // case with indirect uses.
151   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
152                                const ShapeIndex& index,
153                                const HloInstruction* user) const;
154 
155   // Returns true if 'user' (at 'user_index') can share a buffer with its
156   // operand 'operand' (at 'operand_index'). Returns false otherwise.
157   //
158   // REQUIRES: 'operand' is an operand of 'user'.
159   bool CanShareOperandBufferWithUser(HloInstruction* operand,
160                                      const ShapeIndex& operand_index,
161                                      HloInstruction* user,
162                                      const ShapeIndex& user_index) const;
163 
module()164   const HloModule& module() const { return module_; }
165 
166   // Returns true if the operation is an in-place operation and its operand 0
167   // must alias with the output.
168   static bool IsInPlaceOperation(HloOpcode opcode);
169 
170   // Returns true if the operation is the start/done of an asynchronous
171   // operation, where the buffer used/produced by the op needs to stay alive
172   // until the asynchronous operation completes.
173   static bool IsAsynchronousOperationStart(HloOpcode opcode);
174   static bool IsAsynchronousOperationDone(HloOpcode opcode);
175 
176   // Returns a vector consisting of the HloUse (operand number and shape index)
177   // and output shape index of the in-place operations within this HLO.
178   static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs(
179       HloInstruction* instruction);
180 
181  protected:
182   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
183                       bool bitcast_defines_value = false,
184                       const CanShareBuffer& can_share_buffer = nullptr);
185 
186   // 1. During value propagation (Propagate function), always create phi
187   // values once it see multiple inputs merging at the same point. It then
188   // records those phi values as well as their inputs in a phi graph.
189   //
190   // 2. Post value propagation, Dataflow analysis can then do certain
191   // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi
192   // nodes.
193   //
194   // Note that this applies in SSA form, and Both of the functions are
195   // guaranteed to exit.
196   //
197   void OptimizePhiValues();
198 
199   // Returns a new HloValue defined at the given instruction and shape index.
200   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
201                         bool is_phi);
202 
203   // Marks the HloValue with the given ID for deletion.
204   void MarkValueForDeletion(HloValue::Id value_id);
205 
206   // Deletes all HloValues marked for deletion. Should be called after
207   // propagation is complete.
208   void DeleteMarkedValues();
209 
210   // Constructs and initializes the InstructionValueSets of all instructions to
211   // contain exactly the HloValues defined by each instruction. These values can
212   // then propagated throughout the HLO graph by calling Propagate.
213   Status InitializeInstructionValueSets();
214 
215   // Updates the value set of the given instruction based on the values flowing
216   // into the instruction (operands and cross-computation dataflow).
217   bool UpdateInstructionValueSet(HloInstruction* instruction);
218 
219   // Updates the value set for a particular instruction type. Returns whether
220   // the instruction value set changed.
221   bool UpdateBitcastValueSet(HloInstruction* bitcast);
222   bool UpdateCallValueSet(HloInstruction* call);
223   bool UpdateConditionalValueSet(HloInstruction* conditional);
224   bool UpdateCopyValueSet(HloInstruction* copy);
225   bool UpdateCustomCallValueSet(HloInstruction* custom_call);
226   bool UpdateDomainValueSet(HloInstruction* domain);
227   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
228   bool UpdateParameterValueSet(HloInstruction* parameter);
229   bool UpdateCopyStartValueSet(HloInstruction* copy_start);
230   bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
231   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
232   bool UpdateTupleSelectValueSet(HloInstruction* select);
233   bool UpdateSendValueSet(HloInstruction* send);
234   bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
235   bool UpdateTupleValueSet(HloInstruction* tuple);
236   bool UpdateWhileValueSet(HloInstruction* xla_while);
237   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
238   bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start);
239   bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done);
240   bool UpdateAllReduceStartValueSet(HloInstruction* all_reduce_start);
241   bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done);
242   bool UpdateCollectivePermuteStartValueSet(
243       HloInstruction* collective_permute_start);
244   bool UpdateCollectivePermuteDoneValueSet(
245       HloInstruction* collective_permute_done);
246 
247   // Propagates the dataflow through the module. In particular, it propagates
248   // the HloValueSet from its defining instruction to the users of the
249   // instructions.
250   void Propagate();
251 
252   // Returns the result of the SSA Phi function applied to the given inputs at
253   // the given instruction.
254   bool Phi(HloInstruction* instruction,
255            absl::Span<const InstructionValueSet* const> inputs);
256 
257   // Updates the positions of the HloValues in the output of the given
258   // instruction. This should be called after the instruction value set of
259   // 'instruction' has been changed. 'prev_value_set' must point to the previous
260   // state of the value set prior to the change. 'prev_value_set' may be null if
261   // this is the first time positions are being computed. The previous state is
262   // necessary to efficiently remove positions which have been eliminated due to
263   // changes in the instructions' InstructionValueSet.
264   void UpdatePositionsOfValuesAt(
265       HloInstruction* instruction, const InstructionValueSet& new_value_set,
266       const InstructionValueSet* prev_value_set = nullptr);
267 
268   // Verifies various invariants of the dataflow analysis.
269   Status Verify() const;
270 
271   const HloModule& module_;
272   const bool ssa_form_;
273   const bool bitcast_defines_value_;
274 
275   std::unique_ptr<CallGraph> call_graph_;
276 
277   // The map of all HloValues in the module. We pass around pointers to the
278   // mapped HloValues, so the underlying container must keep them valid despite
279   // mutations touching other map entries.
280   std::unordered_map<HloValue::Id, HloValue> values_;
281 
282   // A map from instruction to InstructionValueSet.
283   std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
284 
285   // Values marked for deletion during construction. We don't delete them
286   // immediately because references to them may remain in ValueSets temporarily
287   // during propagation. After construction, these values are deleted.
288   std::vector<HloValue::Id> value_ids_to_delete_;
289 
290   // A vector containing all HloValues sorted by HloValue::Id.
291   std::vector<HloValue*> values_vector_;
292 
293   // The Id to use for the next HloValue.
294   HloValue::Id next_value_id_ = 0;
295 
296   // An explicit graph holding phi values and edges.
297   PhiGraph phi_graph_;
298 
299   // Backend specific function that decides whether an instruction can share
300   // a buffer with its operand.
301   CanShareBuffer can_share_buffer_ = nullptr;
302 };
303 
304 }  // namespace xla
305 
306 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
307