• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
18 
19 #include <stddef.h>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/service/buffer_value.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/shape_tree.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // Abstraction which identifies a specific point in the XLA graph. An
37 // HloPosition specifies a ShapeIndex within the output of a specific
38 // instruction.
39 struct HloPosition {
40   HloInstruction* instruction;
41   ShapeIndex index;
42 
43   // Returns the shape at this position.
44   const Shape& shape() const;
45 
46   string ToString() const;
47 
48   bool operator==(const HloPosition& other) const {
49     return instruction == other.instruction && index == other.index;
50   }
51   bool operator!=(const HloPosition& other) const { return !(*this == other); }
52 
53   // Stable less-than operator using instruction id and index.
54   bool operator<(const HloPosition& other) const {
55     return instruction->unique_id() < other.instruction->unique_id() ||
56            (instruction->unique_id() == other.instruction->unique_id() &&
57             index < other.index);
58   }
59 };
60 
61 std::ostream& operator<<(std::ostream& out, const HloPosition& position);
62 
63 // Defines a single use of an HLO value.
64 struct HloUse {
65   // Instruction at which the value is used.
66   HloInstruction* instruction;
67 
68   // The operand number in which the value is appears.
69   int64 operand_number;
70 
71   // The shape index within the operand in which the value appears.
72   ShapeIndex operand_index;
73 
74   string ToString() const;
75 
76   bool operator==(const HloUse& other) const {
77     return instruction == other.instruction &&
78            operand_number == other.operand_number &&
79            operand_index == other.operand_index;
80   }
81 
82   bool operator!=(const HloUse& other) const { return !(*this == other); }
83 };
84 
85 std::ostream& operator<<(std::ostream& out, const HloUse& use);
86 
87 // HloDataflowAnalysis uses this subclass of BufferValue.
88 class HloValue : public BufferValue {
89  public:
90   // Predicate comparing HloValues by increasing id, useful for std::sort.
IdLessThan(const HloValue * a,const HloValue * b)91   static bool IdLessThan(const HloValue* a, const HloValue* b) {
92     return a->id() < b->id();
93   }
94 
95   // Predicate comparing HloValues by equal id, useful for std::unique.
IdEqual(const HloValue * a,const HloValue * b)96   static bool IdEqual(const HloValue* a, const HloValue* b) {
97     return a->id() == b->id();
98   }
99 
100   // Construct an HloValue defined by 'instruction' at shape index 'index'. If
101   // is_phi is true, then this value is a phi value, for example, at the
102   // parameter of a while body computation. Phi values are only used in the SSA
103   // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
104   HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
105            bool is_phi = false);
~HloValue()106   ~HloValue() override {}
107 
108   // Sets the positions in the module at which the HloValue appears. Updates
109   // uses. Should be called once and only once. The defining position should not
110   // be included in 'positions' as this is set at construction time.
111   void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
112 
113   // Returns whether this value is a phi value.
is_phi()114   bool is_phi() const { return is_phi_; }
115 
116   // Return the position where this value is defined.
defining_position()117   const HloPosition& defining_position() const { return positions_[0]; }
118 
119   // Return the instruction which defines this HloValue.
defining_instruction()120   HloInstruction* defining_instruction() const {
121     return defining_position().instruction;
122   }
123 
instruction()124   HloInstruction* instruction() const override {
125     return defining_instruction();
126   }
127 
128   // Return the shape index at which this HloValue is defined in the output of
129   // its defining instruction.
defining_index()130   const ShapeIndex& defining_index() const { return defining_position().index; }
131 
index()132   const ShapeIndex& index() const override { return defining_index(); }
133 
134   // Return the shape of this HloValue.
shape()135   const Shape& shape() const override { return defining_position().shape(); }
136 
137   // Return all positions of the HloValue in the module.
positions()138   const std::vector<HloPosition>& positions() const { return positions_; }
139 
140   // Return all uses of the HloValue.
uses()141   const std::vector<HloUse>& uses() const { return uses_; }
142 
143   // Get whether this HloValue is live out of the module.
live_out_of_module()144   bool live_out_of_module() const { return live_out_of_module_; }
145 
146   bool operator==(const HloValue& other) const;
147   bool operator!=(const HloValue& other) const;
148 
149   // Return a single-line string representation of the value.
150   string ToShortString() const;
151 
152   string ToString(int indent) const;
153 
ToString()154   string ToString() const override { return ToString(0); }
155 
156  private:
157   // Whether this instruction is a phi value.
158   const bool is_phi_;
159 
160   // The set of positions of this HloValue. The first element is always the
161   // position of the definition.
162   std::vector<HloPosition> positions_;
163 
164   // The set of uses of this HloValue.
165   std::vector<HloUse> uses_;
166 
167   // Whether this value is live out of the HLO module.
168   bool live_out_of_module_ = false;
169 };
170 
171 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
172 
173 // A class representing the possible set of HloValues at a particular point
174 // (shape index in the output of an instruction) in the XLA graph. This set
175 // contains the set of reaching HloValue definitions. For a simple array-shaped
176 // instruction like Add, the HloValueSet of the top-level of the instruction's
177 // output trivially contains only the HloValue defined by the instruction. For
178 // instructions which have non-trivial dataflow such as Tuple or Select, the
179 // HloValueSets of the instruction's output contains one or more HloValues
180 // defined by the instruction's operands or defined further up in the XLA graph.
181 class HloValueSet {
182  public:
183   HloValueSet() = default;
184 
HloValueSet(absl::Span<const HloValue * const> values)185   explicit HloValueSet(absl::Span<const HloValue* const> values)
186       : values_(values.begin(), values.end()) {
187     SortAndUniquifyValues();
188   }
189 
190   // Sets this value set to the union of the given value sets. Returns whether
191   // this value set changed.
192   bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
193 
194   // Return the vector of HloValues in the set. Values in the vector are unique
195   // and stably sorted by value id.
values()196   const std::vector<const HloValue*>& values() const { return values_; }
197 
198   // Adds the value to the set.  Returns true iff the value was added and didn't
199   // already exist in the set.
200   bool AddValue(const HloValue* value);
201 
202   // Clear all values from the set.
Clear()203   void Clear() { values_.clear(); }
204 
205   // Return the unique HLO value in the set. CHECKs if the set does not contain
206   // exactly one value.
GetUniqueValue()207   const HloValue& GetUniqueValue() const {
208     CHECK_EQ(values_.size(), 1);
209     return *values_[0];
210   }
211 
212   bool operator==(const HloValueSet& other) const {
213     if (values_.size() != other.values_.size()) return false;
214     for (size_t i = 0; i < values_.size(); ++i) {
215       if (values_[i]->id() != other.values_[i]->id()) {
216         return false;
217       }
218     }
219     return true;
220   }
221   bool operator!=(const HloValueSet& other) const { return !(*this == other); }
222 
223   string ToString() const;
224 
225  private:
226   // Sorts value_ and removes duplicates. This should be called after adding any
227   // elements to values_.
228   void SortAndUniquifyValues();
229 
230   // HloValues sorted by HloValue::Id.
231   std::vector<const HloValue*> values_;
232 };
233 
234 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
235 
236 // A class collecting the HloValues which might be contained in the output of
237 // an HLO instruction. For array-shaped instructions, an InstructionValueSet
238 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
239 // hold multiple HloValueSets.
240 class InstructionValueSet : public ShapeTree<HloValueSet> {
241  public:
InstructionValueSet(const Shape & shape)242   InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}
243 
244   // Sets this value set to the union of the given value sets. Returns whether
245   // this value set changed.
246   bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
247 
248   string ToString() const;
249 };
250 
251 std::ostream& operator<<(std::ostream& out,
252                          const InstructionValueSet& instruction_value_set);
253 
254 }  // namespace xla
255 
256 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
257