1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 18 19 #include <stddef.h> 20 21 #include <string> 22 #include <tuple> 23 #include <utility> 24 #include <vector> 25 26 #include "absl/algorithm/container.h" 27 #include "absl/container/flat_hash_set.h" 28 #include "absl/types/span.h" 29 #include "tensorflow/compiler/xla/service/buffer_value.h" 30 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 31 #include "tensorflow/compiler/xla/shape_tree.h" 32 #include "tensorflow/compiler/xla/shape_util.h" 33 #include "tensorflow/compiler/xla/types.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 #include "tensorflow/core/platform/logging.h" 36 37 namespace xla { 38 39 // Abstraction which identifies a specific point in the XLA graph. An 40 // HloPosition specifies a ShapeIndex within the output of a specific 41 // instruction. 42 struct HloPosition { 43 HloInstruction* instruction; 44 ShapeIndex index; 45 46 // Returns the shape at this position. 47 const Shape& shape() const; 48 49 std::string ToString() const; 50 51 bool operator==(const HloPosition& other) const { 52 return instruction == other.instruction && index == other.index; 53 } 54 bool operator!=(const HloPosition& other) const { return !(*this == other); } 55 56 // Sort by instruction ID, then index. 57 bool operator<(const HloPosition& other) const { 58 return std::forward_as_tuple(instruction->unique_id(), index) < 59 std::forward_as_tuple(other.instruction->unique_id(), other.index); 60 } 61 62 template <typename H> AbslHashValueHloPosition63 friend H AbslHashValue(H h, const HloPosition& pos) { 64 return H::combine(std::move(h), *pos.instruction, pos.index); 65 } 66 }; 67 68 std::ostream& operator<<(std::ostream& out, const HloPosition& position); 69 70 // Defines a single use of an HLO value. 71 struct HloUse { 72 // Instruction at which the value is used. 73 HloInstruction* instruction; 74 75 // The operand number in which the value is appears. 76 int64_t operand_number; 77 78 // The shape index within the operand in which the value appears. 79 ShapeIndex operand_index; 80 81 std::string ToString() const; 82 83 bool operator==(const HloUse& other) const { 84 return instruction == other.instruction && 85 operand_number == other.operand_number && 86 operand_index == other.operand_index; 87 } 88 89 bool operator!=(const HloUse& other) const { return !(*this == other); } 90 91 template <typename H> AbslHashValueHloUse92 friend H AbslHashValue(H h, const HloUse& use) { 93 return H::combine(std::move(h), use.instruction, use.operand_index, 94 use.operand_number); 95 } 96 }; 97 98 std::ostream& operator<<(std::ostream& out, const HloUse& use); 99 100 // HloDataflowAnalysis uses this subclass of BufferValue. 101 class HloValue : public BufferValue { 102 public: 103 // Predicate comparing HloValues by increasing id, useful for std::sort. IdLessThan(const HloValue * a,const HloValue * b)104 static bool IdLessThan(const HloValue* a, const HloValue* b) { 105 return a->id() < b->id(); 106 } 107 108 // Construct an HloValue defined by 'instruction' at shape index 'index'. If 109 // is_phi is true, then this value is a phi value, for example, at the 110 // parameter of a while body computation. Phi values are only used in the SSA 111 // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). 112 HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, 113 bool is_phi = false); 114 115 // Sets the positions in the module at which the HloValue appears. Should be 116 // called once and only once. The defining position should not be included in 117 // 'positions' as this is set at construction time. 118 void SetPositions(absl::Span<const HloPosition> positions); 119 120 // Returns whether this value is a phi value. is_phi()121 bool is_phi() const { return is_phi_; } 122 123 // Return the position where this value is defined. defining_position()124 const HloPosition& defining_position() const { return positions_[0]; } 125 126 // Return the instruction which defines this HloValue. defining_instruction()127 HloInstruction* defining_instruction() const { 128 return defining_position().instruction; 129 } 130 instruction()131 HloInstruction* instruction() const override { 132 return defining_instruction(); 133 } 134 135 // Return the shape index at which this HloValue is defined in the output of 136 // its defining instruction. defining_index()137 const ShapeIndex& defining_index() const { return defining_position().index; } 138 index()139 const ShapeIndex& index() const override { return defining_index(); } 140 141 // Return the shape of this HloValue. shape()142 const Shape& shape() const override { return defining_position().shape(); } 143 144 // Return all positions of the HloValue in the module. positions()145 const std::vector<HloPosition>& positions() const { return positions_; } 146 147 // Return all uses of the HloValue. This computes the uses lazily, and the 148 // overhead could be non-trivial for the first invocation. Therefore even 149 // though it is marked `const`, it actually can mutate its data members. It is 150 // kept this way to allow passing around const references. GetUses()151 absl::Span<const HloUse> GetUses() const { 152 return uses_.MaybeInitAndGet( 153 [this](std::vector<HloUse>& uses) { ComputeUses(uses); }); 154 } 155 156 // Returns true if this has a position that is the root of the given 157 // computation. 158 bool IsRootOf(const HloComputation* computation) const; 159 160 // Get whether this HloValue is live out of the module. live_out_of_module()161 bool live_out_of_module() const { return live_out_of_module_; } 162 163 bool operator==(const HloValue& other) const { return this == &other; } 164 bool operator!=(const HloValue& other) const { return !(*this == other); } 165 166 // Return a single-line string representation of the value. 167 std::string ToShortString() const; 168 std::string ToString(int indent) const; ToString()169 std::string ToString() const override { return ToString(0); } 170 171 private: 172 template <typename T> 173 class Lazy { 174 public: 175 Lazy() = default; MaybeInitAndGet(absl::FunctionRef<void (T &)> func)176 const T& MaybeInitAndGet(absl::FunctionRef<void(T&)> func) const { 177 if (!initialized_) { 178 func(uses_); 179 initialized_ = true; 180 } 181 return uses_; 182 } 183 184 private: 185 mutable T uses_; 186 mutable bool initialized_ = false; 187 }; 188 // Called when lazily computing the uses. 189 void ComputeUses(std::vector<HloUse>& uses) const; 190 191 // The set of positions of this HloValue. The first element is always the 192 // position of the definition. 193 std::vector<HloPosition> positions_; 194 195 // The set of uses of this HloValue. This is lazily constructed until getting 196 // accessed. 197 Lazy<std::vector<HloUse>> uses_; 198 199 // Whether this instruction is a phi value. 200 const bool is_phi_; 201 202 // Whether this value is live out of the HLO module. 203 bool live_out_of_module_ = false; 204 }; 205 206 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); 207 208 // A class representing the possible set of HloValues at a particular point 209 // (shape index in the output of an instruction) in the XLA graph. This set 210 // contains the set of reaching HloValue definitions. For a simple array-shaped 211 // instruction like Add, the HloValueSet of the top-level of the instruction's 212 // output trivially contains only the HloValue defined by the instruction. For 213 // instructions which have non-trivial dataflow such as Tuple or Select, the 214 // HloValueSets of the instruction's output contains one or more HloValues 215 // defined by the instruction's operands or defined further up in the XLA graph. 216 class HloValueSet { 217 public: 218 HloValueSet() = default; 219 220 explicit HloValueSet(absl::Span<const HloValue* const> values); 221 explicit HloValueSet(const absl::flat_hash_set<const HloValue*>& values); 222 223 // Sets this value set to the union of the given value sets. Returns whether 224 // this value set changed. 225 bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs); 226 227 // Return the vector of HloValues in the set. Values in the vector are unique 228 // and stably sorted by value id. values()229 const std::vector<const HloValue*>& values() const { return values_; } 230 231 // Adds the value to the set. Returns true iff the value was added and didn't 232 // already exist in the set. 233 bool AddValue(const HloValue* value); 234 235 // Clear all values from the set. Clear()236 void Clear() { values_.clear(); } 237 TakeValues()238 std::vector<const HloValue*> TakeValues() { return std::move(values_); } 239 240 // Return the unique HLO value in the set. CHECKs if the set does not contain 241 // exactly one value. GetUniqueValue()242 const HloValue& GetUniqueValue() const { 243 CHECK_EQ(values_.size(), 1); 244 return *values_[0]; 245 } 246 247 bool operator==(const HloValueSet& other) const { 248 if (values_.size() != other.values_.size()) return false; 249 for (size_t i = 0; i < values_.size(); ++i) { 250 if (values_[i]->id() != other.values_[i]->id()) { 251 return false; 252 } 253 } 254 return true; 255 } 256 bool operator!=(const HloValueSet& other) const { return !(*this == other); } 257 258 std::string ToString() const; 259 260 private: 261 // Sorts value_ and removes duplicates. This should be called after adding any 262 // elements to values_. 263 void SortAndUniquifyValues(); 264 265 // HloValues sorted by HloValue::Id. 266 std::vector<const HloValue*> values_; 267 }; 268 269 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); 270 271 // A class collecting the HloValues which might be contained in the output of 272 // an HLO instruction. For array-shaped instructions, an InstructionValueSet 273 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets 274 // hold multiple HloValueSets. 275 class InstructionValueSet : public ShapeTree<HloValueSet> { 276 public: InstructionValueSet(const Shape & shape)277 explicit InstructionValueSet(const Shape& shape) 278 : ShapeTree<HloValueSet>(shape) {} 279 280 // Sets this value set to the union of the given value sets. Returns whether 281 // this value set changed. 282 bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs); 283 284 // Sets this value set to the input value set at the given index. Returns 285 // whether this value set changed. 286 bool AssignUnionOf(const InstructionValueSet& input, 287 ShapeIndexView input_index); 288 289 // Returns true if any value sets for any subshape element is not a 290 // singleton. 291 bool IsAmbiguous() const; 292 293 std::string ToString() const; 294 }; 295 296 std::ostream& operator<<(std::ostream& out, 297 const InstructionValueSet& instruction_value_set); 298 299 } // namespace xla 300 301 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 302