1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 18 19 #include <stddef.h> 20 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/buffer_value.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/shape_tree.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace xla { 36 37 // Abstraction which identifies a specific point in the XLA graph. An 38 // HloPosition specifies a ShapeIndex within the output of a specific 39 // instruction. 40 struct HloPosition { 41 HloInstruction* instruction; 42 ShapeIndex index; 43 44 // Returns the shape at this position. 45 const Shape& shape() const; 46 47 string ToString() const; 48 49 bool operator==(const HloPosition& other) const { 50 return instruction == other.instruction && index == other.index; 51 } 52 bool operator!=(const HloPosition& other) const { return !(*this == other); } 53 54 // Stable less-than operator using instruction id and index. 55 bool operator<(const HloPosition& other) const { 56 return instruction->unique_id() < other.instruction->unique_id() || 57 (instruction->unique_id() == other.instruction->unique_id() && 58 index < other.index); 59 } 60 61 template <typename H> AbslHashValueHloPosition62 friend H AbslHashValue(H h, const HloPosition& pos) { 63 return H::combine(std::move(h), pos.instruction->Hash(), pos.index); 64 } 65 }; 66 67 std::ostream& operator<<(std::ostream& out, const HloPosition& position); 68 69 // Defines a single use of an HLO value. 70 struct HloUse { 71 // Instruction at which the value is used. 72 HloInstruction* instruction; 73 74 // The operand number in which the value is appears. 75 int64 operand_number; 76 77 // The shape index within the operand in which the value appears. 78 ShapeIndex operand_index; 79 80 string ToString() const; 81 82 bool operator==(const HloUse& other) const { 83 return instruction == other.instruction && 84 operand_number == other.operand_number && 85 operand_index == other.operand_index; 86 } 87 88 bool operator!=(const HloUse& other) const { return !(*this == other); } 89 90 template <typename H> AbslHashValueHloUse91 friend H AbslHashValue(H h, const HloUse& use) { 92 return H::combine(std::move(h), use.instruction, use.operand_index, 93 use.operand_number); 94 } 95 }; 96 97 std::ostream& operator<<(std::ostream& out, const HloUse& use); 98 99 // HloDataflowAnalysis uses this subclass of BufferValue. 100 class HloValue : public BufferValue { 101 public: 102 // Predicate comparing HloValues by increasing id, useful for std::sort. IdLessThan(const HloValue * a,const HloValue * b)103 static bool IdLessThan(const HloValue* a, const HloValue* b) { 104 return a->id() < b->id(); 105 } 106 107 // Predicate comparing HloValues by equal id, useful for std::unique. IdEqual(const HloValue * a,const HloValue * b)108 static bool IdEqual(const HloValue* a, const HloValue* b) { 109 return a->id() == b->id(); 110 } 111 112 // Construct an HloValue defined by 'instruction' at shape index 'index'. If 113 // is_phi is true, then this value is a phi value, for example, at the 114 // parameter of a while body computation. Phi values are only used in the SSA 115 // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). 116 HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, 117 bool is_phi = false); ~HloValue()118 ~HloValue() override {} 119 120 // Sets the positions in the module at which the HloValue appears. Updates 121 // uses. Should be called once and only once. The defining position should not 122 // be included in 'positions' as this is set at construction time. 123 void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions); 124 125 // Returns whether this value is a phi value. is_phi()126 bool is_phi() const { return is_phi_; } 127 128 // Return the position where this value is defined. defining_position()129 const HloPosition& defining_position() const { return positions_[0]; } 130 131 // Return the instruction which defines this HloValue. defining_instruction()132 HloInstruction* defining_instruction() const { 133 return defining_position().instruction; 134 } 135 instruction()136 HloInstruction* instruction() const override { 137 return defining_instruction(); 138 } 139 140 // Return the shape index at which this HloValue is defined in the output of 141 // its defining instruction. defining_index()142 const ShapeIndex& defining_index() const { return defining_position().index; } 143 index()144 const ShapeIndex& index() const override { return defining_index(); } 145 146 // Return the shape of this HloValue. shape()147 const Shape& shape() const override { return defining_position().shape(); } 148 149 // Return all positions of the HloValue in the module. positions()150 const std::vector<HloPosition>& positions() const { return positions_; } 151 152 // Return all uses of the HloValue. uses()153 const std::vector<HloUse>& uses() const { return uses_; } 154 155 // Get whether this HloValue is live out of the module. live_out_of_module()156 bool live_out_of_module() const { return live_out_of_module_; } 157 158 bool operator==(const HloValue& other) const; 159 bool operator!=(const HloValue& other) const; 160 161 // Return a single-line string representation of the value. 162 string ToShortString() const; 163 164 string ToString(int indent) const; 165 ToString()166 string ToString() const override { return ToString(0); } 167 168 private: 169 // Whether this instruction is a phi value. 170 const bool is_phi_; 171 172 // The set of positions of this HloValue. The first element is always the 173 // position of the definition. 174 std::vector<HloPosition> positions_; 175 176 // The set of uses of this HloValue. 177 std::vector<HloUse> uses_; 178 179 // Whether this value is live out of the HLO module. 180 bool live_out_of_module_ = false; 181 }; 182 183 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); 184 185 // A class representing the possible set of HloValues at a particular point 186 // (shape index in the output of an instruction) in the XLA graph. This set 187 // contains the set of reaching HloValue definitions. For a simple array-shaped 188 // instruction like Add, the HloValueSet of the top-level of the instruction's 189 // output trivially contains only the HloValue defined by the instruction. For 190 // instructions which have non-trivial dataflow such as Tuple or Select, the 191 // HloValueSets of the instruction's output contains one or more HloValues 192 // defined by the instruction's operands or defined further up in the XLA graph. 193 class HloValueSet { 194 public: 195 HloValueSet() = default; 196 HloValueSet(absl::Span<const HloValue * const> values)197 explicit HloValueSet(absl::Span<const HloValue* const> values) 198 : values_(values.begin(), values.end()) { 199 SortAndUniquifyValues(); 200 } 201 202 // Sets this value set to the union of the given value sets. Returns whether 203 // this value set changed. 204 bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs); 205 206 // Return the vector of HloValues in the set. Values in the vector are unique 207 // and stably sorted by value id. values()208 const std::vector<const HloValue*>& values() const { return values_; } 209 210 // Adds the value to the set. Returns true iff the value was added and didn't 211 // already exist in the set. 212 bool AddValue(const HloValue* value); 213 214 // Clear all values from the set. Clear()215 void Clear() { values_.clear(); } 216 217 // Return the unique HLO value in the set. CHECKs if the set does not contain 218 // exactly one value. GetUniqueValue()219 const HloValue& GetUniqueValue() const { 220 CHECK_EQ(values_.size(), 1); 221 return *values_[0]; 222 } 223 224 bool operator==(const HloValueSet& other) const { 225 if (values_.size() != other.values_.size()) return false; 226 for (size_t i = 0; i < values_.size(); ++i) { 227 if (values_[i]->id() != other.values_[i]->id()) { 228 return false; 229 } 230 } 231 return true; 232 } 233 bool operator!=(const HloValueSet& other) const { return !(*this == other); } 234 235 string ToString() const; 236 237 private: 238 // Sorts value_ and removes duplicates. This should be called after adding any 239 // elements to values_. 240 void SortAndUniquifyValues(); 241 242 // HloValues sorted by HloValue::Id. 243 std::vector<const HloValue*> values_; 244 }; 245 246 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); 247 248 // A class collecting the HloValues which might be contained in the output of 249 // an HLO instruction. For array-shaped instructions, an InstructionValueSet 250 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets 251 // hold multiple HloValueSets. 252 class InstructionValueSet : public ShapeTree<HloValueSet> { 253 public: InstructionValueSet(const Shape & shape)254 explicit InstructionValueSet(const Shape& shape) 255 : ShapeTree<HloValueSet>(shape) {} 256 257 // Sets this value set to the union of the given value sets. Returns whether 258 // this value set changed. 259 bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs); 260 261 // Returns true if any value sets for any subshape element is not a 262 // singleton. 263 bool IsAmbiguous() const; 264 265 string ToString() const; 266 }; 267 268 std::ostream& operator<<(std::ostream& out, 269 const InstructionValueSet& instruction_value_set); 270 271 } // namespace xla 272 273 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 274