• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
18 
19 #include <stddef.h>
20 
21 #include <string>
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/shape_tree.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 
37 // Abstraction which identifies a specific point in the XLA graph. An
38 // HloPosition specifies a ShapeIndex within the output of a specific
39 // instruction.
40 struct HloPosition {
41   HloInstruction* instruction;
42   ShapeIndex index;
43 
44   // Returns the shape at this position.
45   const Shape& shape() const;
46 
47   string ToString() const;
48 
49   bool operator==(const HloPosition& other) const {
50     return instruction == other.instruction && index == other.index;
51   }
52   bool operator!=(const HloPosition& other) const { return !(*this == other); }
53 
54   // Stable less-than operator using instruction id and index.
55   bool operator<(const HloPosition& other) const {
56     return instruction->unique_id() < other.instruction->unique_id() ||
57            (instruction->unique_id() == other.instruction->unique_id() &&
58             index < other.index);
59   }
60 };
61 
62 std::ostream& operator<<(std::ostream& out, const HloPosition& position);
63 
64 // Defines a single use of an HLO value.
65 struct HloUse {
66   // Instruction at which the value is used.
67   HloInstruction* instruction;
68 
69   // The operand number in which the value is appears.
70   int64 operand_number;
71 
72   // The shape index within the operand in which the value appears.
73   ShapeIndex operand_index;
74 
75   string ToString() const;
76 
77   bool operator==(const HloUse& other) const {
78     return instruction == other.instruction &&
79            operand_number == other.operand_number &&
80            operand_index == other.operand_index;
81   }
82 
83   bool operator!=(const HloUse& other) const { return !(*this == other); }
84 };
85 
86 std::ostream& operator<<(std::ostream& out, const HloUse& use);
87 
88 // HloDataflowAnalysis uses this subclass of BufferValue.
89 class HloValue : public BufferValue {
90  public:
91   // Predicate comparing HloValues by increasing id, useful for std::sort.
IdLessThan(const HloValue * a,const HloValue * b)92   static bool IdLessThan(const HloValue* a, const HloValue* b) {
93     return a->id() < b->id();
94   }
95 
96   // Predicate comparing HloValues by equal id, useful for std::unique.
IdEqual(const HloValue * a,const HloValue * b)97   static bool IdEqual(const HloValue* a, const HloValue* b) {
98     return a->id() == b->id();
99   }
100 
101   // Construct an HloValue defined by 'instruction' at shape index 'index'. If
102   // is_phi is true, then this value is a phi value, for example, at the
103   // parameter of a while body computation. Phi values are only used in the SSA
104   // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
105   HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
106            bool is_phi = false);
~HloValue()107   ~HloValue() override {}
108 
109   // Sets the positions in the module at which the HloValue appears. Updates
110   // uses. Should be called once and only once. The defining position should not
111   // be included in 'positions' as this is set at construction time.
112   void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
113 
114   // Returns whether this value is a phi value.
is_phi()115   bool is_phi() const { return is_phi_; }
116 
117   // Return the position where this value is defined.
defining_position()118   const HloPosition& defining_position() const { return positions_[0]; }
119 
120   // Return the instruction which defines this HloValue.
defining_instruction()121   HloInstruction* defining_instruction() const {
122     return defining_position().instruction;
123   }
124 
instruction()125   HloInstruction* instruction() const override {
126     return defining_instruction();
127   }
128 
129   // Return the shape index at which this HloValue is defined in the output of
130   // its defining instruction.
defining_index()131   const ShapeIndex& defining_index() const { return defining_position().index; }
132 
index()133   const ShapeIndex& index() const override { return defining_index(); }
134 
135   // Return the shape of this HloValue.
shape()136   const Shape& shape() const override { return defining_position().shape(); }
137 
138   // Return all positions of the HloValue in the module.
positions()139   const std::vector<HloPosition>& positions() const { return positions_; }
140 
141   // Return all uses of the HloValue.
uses()142   const std::vector<HloUse>& uses() const { return uses_; }
143 
144   // Get whether this HloValue is live out of the module.
live_out_of_module()145   bool live_out_of_module() const { return live_out_of_module_; }
146 
147   bool operator==(const HloValue& other) const;
148   bool operator!=(const HloValue& other) const;
149 
150   // Return a single-line string representation of the value.
151   string ToShortString() const;
152 
153   string ToString(int indent) const;
154 
ToString()155   string ToString() const override { return ToString(0); }
156 
157  private:
158   // Whether this instruction is a phi value.
159   const bool is_phi_;
160 
161   // The set of positions of this HloValue. The first element is always the
162   // position of the definition.
163   std::vector<HloPosition> positions_;
164 
165   // The set of uses of this HloValue.
166   std::vector<HloUse> uses_;
167 
168   // Whether this value is live out of the HLO module.
169   bool live_out_of_module_ = false;
170 };
171 
172 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
173 
174 // A class representing the possible set of HloValues at a particular point
175 // (shape index in the output of an instruction) in the XLA graph. This set
176 // contains the set of reaching HloValue definitions. For a simple array-shaped
177 // instruction like Add, the HloValueSet of the top-level of the instruction's
178 // output trivially contains only the HloValue defined by the instruction. For
179 // instructions which have non-trivial dataflow such as Tuple or Select, the
180 // HloValueSets of the instruction's output contains one or more HloValues
181 // defined by the instruction's operands or defined further up in the XLA graph.
182 class HloValueSet {
183  public:
184   HloValueSet() = default;
185 
HloValueSet(absl::Span<const HloValue * const> values)186   explicit HloValueSet(absl::Span<const HloValue* const> values)
187       : values_(values.begin(), values.end()) {
188     SortAndUniquifyValues();
189   }
190 
191   // Sets this value set to the union of the given value sets. Returns whether
192   // this value set changed.
193   bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
194 
195   // Return the vector of HloValues in the set. Values in the vector are unique
196   // and stably sorted by value id.
values()197   const std::vector<const HloValue*>& values() const { return values_; }
198 
199   // Adds the value to the set.  Returns true iff the value was added and didn't
200   // already exist in the set.
201   bool AddValue(const HloValue* value);
202 
203   // Clear all values from the set.
Clear()204   void Clear() { values_.clear(); }
205 
206   // Return the unique HLO value in the set. CHECKs if the set does not contain
207   // exactly one value.
GetUniqueValue()208   const HloValue& GetUniqueValue() const {
209     CHECK_EQ(values_.size(), 1);
210     return *values_[0];
211   }
212 
213   bool operator==(const HloValueSet& other) const {
214     if (values_.size() != other.values_.size()) return false;
215     for (size_t i = 0; i < values_.size(); ++i) {
216       if (values_[i]->id() != other.values_[i]->id()) {
217         return false;
218       }
219     }
220     return true;
221   }
222   bool operator!=(const HloValueSet& other) const { return !(*this == other); }
223 
224   string ToString() const;
225 
226  private:
227   // Sorts value_ and removes duplicates. This should be called after adding any
228   // elements to values_.
229   void SortAndUniquifyValues();
230 
231   // HloValues sorted by HloValue::Id.
232   std::vector<const HloValue*> values_;
233 };
234 
235 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
236 
237 // A class collecting the HloValues which might be contained in the output of
238 // an HLO instruction. For array-shaped instructions, an InstructionValueSet
239 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
240 // hold multiple HloValueSets.
241 class InstructionValueSet : public ShapeTree<HloValueSet> {
242  public:
InstructionValueSet(const Shape & shape)243   InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}
244 
245   // Sets this value set to the union of the given value sets. Returns whether
246   // this value set changed.
247   bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
248 
249   // Returns true if any value sets for any subshape element is not a
250   // singleton.
251   bool IsAmbiguous() const;
252 
253   string ToString() const;
254 };
255 
256 std::ostream& operator<<(std::ostream& out,
257                          const InstructionValueSet& instruction_value_set);
258 
259 }  // namespace xla
260 
261 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
262