• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
18 
#include <stddef.h>

#include <algorithm>
#include <iosfwd>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/function_ref.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
39 // Abstraction which identifies a specific point in the XLA graph. An
40 // HloPosition specifies a ShapeIndex within the output of a specific
41 // instruction.
42 struct HloPosition {
43   HloInstruction* instruction;
44   ShapeIndex index;
45 
46   // Returns the shape at this position.
47   const Shape& shape() const;
48 
49   std::string ToString() const;
50 
51   bool operator==(const HloPosition& other) const {
52     return instruction == other.instruction && index == other.index;
53   }
54   bool operator!=(const HloPosition& other) const { return !(*this == other); }
55 
56   // Sort by instruction ID, then index.
57   bool operator<(const HloPosition& other) const {
58     return std::forward_as_tuple(instruction->unique_id(), index) <
59            std::forward_as_tuple(other.instruction->unique_id(), other.index);
60   }
61 
62   template <typename H>
AbslHashValueHloPosition63   friend H AbslHashValue(H h, const HloPosition& pos) {
64     return H::combine(std::move(h), *pos.instruction, pos.index);
65   }
66 };
67 
68 std::ostream& operator<<(std::ostream& out, const HloPosition& position);
69 
70 // Defines a single use of an HLO value.
71 struct HloUse {
72   // Instruction at which the value is used.
73   HloInstruction* instruction;
74 
75   // The operand number in which the value is appears.
76   int64_t operand_number;
77 
78   // The shape index within the operand in which the value appears.
79   ShapeIndex operand_index;
80 
81   std::string ToString() const;
82 
83   bool operator==(const HloUse& other) const {
84     return instruction == other.instruction &&
85            operand_number == other.operand_number &&
86            operand_index == other.operand_index;
87   }
88 
89   bool operator!=(const HloUse& other) const { return !(*this == other); }
90 
91   template <typename H>
AbslHashValueHloUse92   friend H AbslHashValue(H h, const HloUse& use) {
93     return H::combine(std::move(h), use.instruction, use.operand_index,
94                       use.operand_number);
95   }
96 };
97 
98 std::ostream& operator<<(std::ostream& out, const HloUse& use);
99 
100 // HloDataflowAnalysis uses this subclass of BufferValue.
101 class HloValue : public BufferValue {
102  public:
103   // Predicate comparing HloValues by increasing id, useful for std::sort.
IdLessThan(const HloValue * a,const HloValue * b)104   static bool IdLessThan(const HloValue* a, const HloValue* b) {
105     return a->id() < b->id();
106   }
107 
108   // Construct an HloValue defined by 'instruction' at shape index 'index'. If
109   // is_phi is true, then this value is a phi value, for example, at the
110   // parameter of a while body computation. Phi values are only used in the SSA
111   // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
112   HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
113            bool is_phi = false);
114 
115   // Sets the positions in the module at which the HloValue appears. Should be
116   // called once and only once. The defining position should not be included in
117   // 'positions' as this is set at construction time.
118   void SetPositions(absl::Span<const HloPosition> positions);
119 
120   // Returns whether this value is a phi value.
is_phi()121   bool is_phi() const { return is_phi_; }
122 
123   // Return the position where this value is defined.
defining_position()124   const HloPosition& defining_position() const { return positions_[0]; }
125 
126   // Return the instruction which defines this HloValue.
defining_instruction()127   HloInstruction* defining_instruction() const {
128     return defining_position().instruction;
129   }
130 
instruction()131   HloInstruction* instruction() const override {
132     return defining_instruction();
133   }
134 
135   // Return the shape index at which this HloValue is defined in the output of
136   // its defining instruction.
defining_index()137   const ShapeIndex& defining_index() const { return defining_position().index; }
138 
index()139   const ShapeIndex& index() const override { return defining_index(); }
140 
141   // Return the shape of this HloValue.
shape()142   const Shape& shape() const override { return defining_position().shape(); }
143 
144   // Return all positions of the HloValue in the module.
positions()145   const std::vector<HloPosition>& positions() const { return positions_; }
146 
147   // Return all uses of the HloValue. This computes the uses lazily, and the
148   // overhead could be non-trivial for the first invocation. Therefore even
149   // though it is marked `const`, it actually can mutate its data members. It is
150   // kept this way to allow passing around const references.
GetUses()151   absl::Span<const HloUse> GetUses() const {
152     return uses_.MaybeInitAndGet(
153         [this](std::vector<HloUse>& uses) { ComputeUses(uses); });
154   }
155 
156   // Returns true if this has a position that is the root of the given
157   // computation.
158   bool IsRootOf(const HloComputation* computation) const;
159 
160   // Get whether this HloValue is live out of the module.
live_out_of_module()161   bool live_out_of_module() const { return live_out_of_module_; }
162 
163   bool operator==(const HloValue& other) const { return this == &other; }
164   bool operator!=(const HloValue& other) const { return !(*this == other); }
165 
166   // Return a single-line string representation of the value.
167   std::string ToShortString() const;
168   std::string ToString(int indent) const;
ToString()169   std::string ToString() const override { return ToString(0); }
170 
171  private:
172   template <typename T>
173   class Lazy {
174    public:
175     Lazy() = default;
MaybeInitAndGet(absl::FunctionRef<void (T &)> func)176     const T& MaybeInitAndGet(absl::FunctionRef<void(T&)> func) const {
177       if (!initialized_) {
178         func(uses_);
179         initialized_ = true;
180       }
181       return uses_;
182     }
183 
184    private:
185     mutable T uses_;
186     mutable bool initialized_ = false;
187   };
188   // Called when lazily computing the uses.
189   void ComputeUses(std::vector<HloUse>& uses) const;
190 
191   // The set of positions of this HloValue. The first element is always the
192   // position of the definition.
193   std::vector<HloPosition> positions_;
194 
195   // The set of uses of this HloValue. This is lazily constructed until getting
196   // accessed.
197   Lazy<std::vector<HloUse>> uses_;
198 
199   // Whether this instruction is a phi value.
200   const bool is_phi_;
201 
202   // Whether this value is live out of the HLO module.
203   bool live_out_of_module_ = false;
204 };
205 
206 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
207 
208 // A class representing the possible set of HloValues at a particular point
209 // (shape index in the output of an instruction) in the XLA graph. This set
210 // contains the set of reaching HloValue definitions. For a simple array-shaped
211 // instruction like Add, the HloValueSet of the top-level of the instruction's
212 // output trivially contains only the HloValue defined by the instruction. For
213 // instructions which have non-trivial dataflow such as Tuple or Select, the
214 // HloValueSets of the instruction's output contains one or more HloValues
215 // defined by the instruction's operands or defined further up in the XLA graph.
216 class HloValueSet {
217  public:
218   HloValueSet() = default;
219 
220   explicit HloValueSet(absl::Span<const HloValue* const> values);
221   explicit HloValueSet(const absl::flat_hash_set<const HloValue*>& values);
222 
223   // Sets this value set to the union of the given value sets. Returns whether
224   // this value set changed.
225   bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
226 
227   // Return the vector of HloValues in the set. Values in the vector are unique
228   // and stably sorted by value id.
values()229   const std::vector<const HloValue*>& values() const { return values_; }
230 
231   // Adds the value to the set.  Returns true iff the value was added and didn't
232   // already exist in the set.
233   bool AddValue(const HloValue* value);
234 
235   // Clear all values from the set.
Clear()236   void Clear() { values_.clear(); }
237 
TakeValues()238   std::vector<const HloValue*> TakeValues() { return std::move(values_); }
239 
240   // Return the unique HLO value in the set. CHECKs if the set does not contain
241   // exactly one value.
GetUniqueValue()242   const HloValue& GetUniqueValue() const {
243     CHECK_EQ(values_.size(), 1);
244     return *values_[0];
245   }
246 
247   bool operator==(const HloValueSet& other) const {
248     if (values_.size() != other.values_.size()) return false;
249     for (size_t i = 0; i < values_.size(); ++i) {
250       if (values_[i]->id() != other.values_[i]->id()) {
251         return false;
252       }
253     }
254     return true;
255   }
256   bool operator!=(const HloValueSet& other) const { return !(*this == other); }
257 
258   std::string ToString() const;
259 
260  private:
261   // Sorts value_ and removes duplicates. This should be called after adding any
262   // elements to values_.
263   void SortAndUniquifyValues();
264 
265   // HloValues sorted by HloValue::Id.
266   std::vector<const HloValue*> values_;
267 };
268 
269 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
270 
271 // A class collecting the HloValues which might be contained in the output of
272 // an HLO instruction. For array-shaped instructions, an InstructionValueSet
273 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
274 // hold multiple HloValueSets.
275 class InstructionValueSet : public ShapeTree<HloValueSet> {
276  public:
InstructionValueSet(const Shape & shape)277   explicit InstructionValueSet(const Shape& shape)
278       : ShapeTree<HloValueSet>(shape) {}
279 
280   // Sets this value set to the union of the given value sets. Returns whether
281   // this value set changed.
282   bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
283 
284   // Sets this value set to the input value set at the given index. Returns
285   // whether this value set changed.
286   bool AssignUnionOf(const InstructionValueSet& input,
287                      ShapeIndexView input_index);
288 
289   // Returns true if any value sets for any subshape element is not a
290   // singleton.
291   bool IsAmbiguous() const;
292 
293   std::string ToString() const;
294 };
295 
296 std::ostream& operator<<(std::ostream& out,
297                          const InstructionValueSet& instruction_value_set);
298 
299 }  // namespace xla
300 
301 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
302