• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Analysis for determining the possible set of values for all positions
17 // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
18 // tracking values across computation boundaries.
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
22 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
39 
40 namespace xla {
41 
42 // Analysis which identifies all HLO values and their uses in an HLO module.
43 class HloDataflowAnalysis {
44  public:
45   // Infrastructure for passing may-alias hints: HLO passes can populate the
46   // may-alias table. If an empty optional is returned, default rules are used.
47   //
48   // The first parameter of the function should be the instruction, the
49   // second parameter should be an operand of the instruction. The third
50   // parameter should be the output index of the instruction.
51   using CanShareBuffer = std::function<absl::optional<bool>(
52       const HloInstruction* instr, const HloInstruction* operand,
53       const ShapeIndex& user_index)>;
54 
55   // Runs dataflow analysis on the given module. Parameters:
56   //
57   //   ssa_form : If true then new values are defined at the merge points of
58   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
59   //     values".  The merge is formed by the init value and loop backedge. The
60   //     SSA form is minimal in that a new phi value is defined only if the
61   //     merge point is reachable by multiple different values. The SSA form is
62   //     also in loop-closed form in that no values defined inside of a loop
63   //     (while body) is used outside of the loop.
64   //
65   //     If ssa_form is false, then merge points do not define new
66   //     values. Rather, the HloValueSet for the merge point contains the union
67   //     of the merged HloValues.
68   //
69   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
70   //     a new HLO value in the analysis. If false then Bitcast forwards the
71   //     value of its operand.
72   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
73       const HloModule& module, bool ssa_form = false,
74       bool bitcast_defines_value = false,
75       const CanShareBuffer& can_share_buffer = nullptr);
76 
77   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
78 
79   // Returns true if 'instruction' defines an HLO value at the given shape index
80   // of its output.
81   bool ValueIsDefinedAt(const HloInstruction* instruction,
82                         const ShapeIndex& index = {}) const;
83 
84   // Returns the HloValue defined by 'instruction' at the given shape index of
85   // its output.
86   //
87   // Precondition: ValueIsDefinedAt is true for this instruction and index.
88   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
89                                     const ShapeIndex& index = {}) const;
90   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
91                               const ShapeIndex& index = {});
92 
93   // Returns the InstructionValueSet for the given instruction.
94   const InstructionValueSet& GetInstructionValueSet(
95       const HloInstruction* instruction) const;
96   InstructionValueSet& GetInstructionValueSet(
97       const HloInstruction* instruction);
98 
99   // Returns all values that are contained in the output of this instruction in
100   // a flattened set.
101   HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;
102 
103   // Returns the HloValueSet for the given instruction at the given index or the
104   // given position.
105   const HloValueSet& GetValueSet(const HloInstruction* instruction,
106                                  const ShapeIndex& index = {}) const;
107   const HloValueSet& GetValueSet(const HloPosition& position) const;
108   HloValueSet& GetValueSet(const HloPosition& position);
109   HloValueSet& GetValueSet(const HloInstruction* instruction,
110                            const ShapeIndex& index = {});
111 
112   // Returns the unique value in the HloValueSet at the given instruction and
113   // shape index. CHECKs if the value set does not contain a exactly one value.
114   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
115                                    const ShapeIndex& index = {}) const {
116     return GetValueSet(instruction, index).GetUniqueValue();
117   }
118   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
119                              const ShapeIndex& index = {}) {
120     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
121   }
122 
123   // Returns the HloValue with the given Id.
124   const HloValue& GetValue(HloValue::Id value_id) const;
125   HloValue& GetValue(HloValue::Id value_id);
126 
127   // Returns the total number of HloValues.
value_count()128   int64 value_count() const { return values_.size(); }
129 
130   // Returns a vector of all HloValues stabily sorted by HloValue::Id.
values()131   const std::vector<HloValue*>& values() const { return values_vector_; }
132 
133   // Returns the call graph used for computing the dataflow.
call_graph()134   const CallGraph& call_graph() const { return *call_graph_; }
135 
136   string ToString() const;
137 
138   // Returns true if 'user' cannot possibly use the buffer at 'index' in
139   // 'operand'. Returns false otherwise.
140   //
141   // 'operand' does not have to be an operand of 'user'. This can be the case
142   // with indirect uses.
143   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
144                                const ShapeIndex& index,
145                                const HloInstruction* user) const;
146 
147   // Returns true if 'user' (at 'user_index') can share a buffer with its
148   // operand 'operand' (at 'operand_index'). Returns false otherwise.
149   //
150   // REQUIRES: 'operand' is an operand of 'user'.
151   bool CanShareOperandBufferWithUser(HloInstruction* operand,
152                                      const ShapeIndex& operand_index,
153                                      HloInstruction* user,
154                                      const ShapeIndex& user_index) const;
155 
module()156   const HloModule& module() const { return module_; }
157 
158  protected:
159   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
160                       bool bitcast_defines_value = false,
161                       const CanShareBuffer& can_share_buffer = nullptr);
162 
163   // Returns a new HloValue defined at the given instruction and shape index.
164   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
165                         bool is_phi = false);
166 
167   // Marks the HloValue with the given ID for deletion.
168   void MarkValueForDeletion(HloValue::Id value_id);
169 
170   // Deletes all HloValues marked for deletion. Should be called after
171   // propagation is complete.
172   void DeleteMarkedValues();
173 
174   // Constructs and initializes the InstructionValueSets of all instructions to
175   // contain exactly the HloValues defined by each instruction. These values can
176   // then propagated throughout the HLO graph by calling Propagate.
177   Status InitializeInstructionValueSets();
178 
179   // Updates the value set of the given instruction based on the values flowing
180   // into the instruction (operands and cross-computation dataflow).
181   bool UpdateInstructionValueSet(HloInstruction* instruction);
182 
183   // Updates the value set for a particular instruction type. Returns whether
184   // the instruction value set changed.
185   bool UpdateBitcastValueSet(HloInstruction* bitcast);
186   bool UpdateCallValueSet(HloInstruction* call);
187   bool UpdateConditionalValueSet(HloInstruction* conditional);
188   bool UpdateCopyValueSet(HloInstruction* copy);
189   bool UpdateDomainValueSet(HloInstruction* domain);
190   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
191   bool UpdateParameterValueSet(HloInstruction* parameter);
192   bool UpdateCopyStartValueSet(HloInstruction* copy_start);
193   bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
194   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
195   bool UpdateTupleSelectValueSet(HloInstruction* select);
196   bool UpdateSendValueSet(HloInstruction* send);
197   bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
198   bool UpdateTupleValueSet(HloInstruction* tuple);
199   bool UpdateWhileValueSet(HloInstruction* xla_while);
200   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
201 
202   // Propagates the dataflow through the module. In particular, it propagates
203   // the HloValueSet from its defining instruction to the users of the
204   // instructions.
205   void Propagate();
206 
207   // Returns the result of the SSA Phi function applied to the given inputs at
208   // the given instruction.
209   bool Phi(HloInstruction* instruction,
210            absl::Span<const InstructionValueSet* const> inputs);
211 
212   // Updates the positions of the HloValues in the output of the given
213   // instruction. This should be called after the instruction value set of
214   // 'instruction' has been changed. 'prev_value_set' must point to the previous
215   // state of the value set prior to the change. 'prev_value_set' may be null if
216   // this is the first time positions are being computed. The previous state is
217   // necessary to efficiently remove positions which have been eliminated due to
218   // changes in the instructions' InstructionValueSet.
219   void UpdatePositionsOfValuesAt(
220       HloInstruction* instruction, const InstructionValueSet& new_value_set,
221       const InstructionValueSet* prev_value_set = nullptr);
222 
223   // Verifies various invariants of the dataflow analysis.
224   Status Verify() const;
225 
226   const HloModule& module_;
227   const bool ssa_form_;
228   const bool bitcast_defines_value_;
229 
230   std::unique_ptr<CallGraph> call_graph_;
231 
232   // The map of all HloValues in the module. We pass around pointers to the
233   // mapped HloValues, so the underlying container must keep them valid despite
234   // mutations touching other map entries.
235   std::unordered_map<HloValue::Id, HloValue> values_;
236 
237   // A map from instruction to InstructionValueSet.
238   std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
239 
240   // Values marked for deletion during construction. We don't delete them
241   // immediately because references to them may remain in ValueSets temporarily
242   // during propagation. After construction, these values are deleted.
243   std::vector<HloValue::Id> value_ids_to_delete_;
244 
245   // A vector containing all HloValues sorted by HloValue::Id.
246   std::vector<HloValue*> values_vector_;
247 
248   // The Id to use for the next HloValue.
249   HloValue::Id next_value_id_ = 0;
250 
251   // Backend specific function that decides whether an instruction can share
252   // a buffer with its operand.
253   CanShareBuffer can_share_buffer_ = nullptr;
254 };
255 
256 }  // namespace xla
257 
258 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
259