1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 18 19 #include <stddef.h> 20 #include <string> 21 #include <vector> 22 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/service/buffer_value.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/shape_tree.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // Abstraction which identifies a specific point in the XLA graph. An 37 // HloPosition specifies a ShapeIndex within the output of a specific 38 // instruction. 39 struct HloPosition { 40 HloInstruction* instruction; 41 ShapeIndex index; 42 43 // Returns the shape at this position. 44 const Shape& shape() const; 45 46 string ToString() const; 47 48 bool operator==(const HloPosition& other) const { 49 return instruction == other.instruction && index == other.index; 50 } 51 bool operator!=(const HloPosition& other) const { return !(*this == other); } 52 53 // Stable less-than operator using instruction id and index. 54 bool operator<(const HloPosition& other) const { 55 return instruction->unique_id() < other.instruction->unique_id() || 56 (instruction->unique_id() == other.instruction->unique_id() && 57 index < other.index); 58 } 59 }; 60 61 std::ostream& operator<<(std::ostream& out, const HloPosition& position); 62 63 // Defines a single use of an HLO value. 64 struct HloUse { 65 // Instruction at which the value is used. 66 HloInstruction* instruction; 67 68 // The operand number in which the value is appears. 69 int64 operand_number; 70 71 // The shape index within the operand in which the value appears. 72 ShapeIndex operand_index; 73 74 string ToString() const; 75 76 bool operator==(const HloUse& other) const { 77 return instruction == other.instruction && 78 operand_number == other.operand_number && 79 operand_index == other.operand_index; 80 } 81 82 bool operator!=(const HloUse& other) const { return !(*this == other); } 83 }; 84 85 std::ostream& operator<<(std::ostream& out, const HloUse& use); 86 87 // HloDataflowAnalysis uses this subclass of BufferValue. 88 class HloValue : public BufferValue { 89 public: 90 // Predicate comparing HloValues by increasing id, useful for std::sort. IdLessThan(const HloValue * a,const HloValue * b)91 static bool IdLessThan(const HloValue* a, const HloValue* b) { 92 return a->id() < b->id(); 93 } 94 95 // Predicate comparing HloValues by equal id, useful for std::unique. IdEqual(const HloValue * a,const HloValue * b)96 static bool IdEqual(const HloValue* a, const HloValue* b) { 97 return a->id() == b->id(); 98 } 99 100 // Construct an HloValue defined by 'instruction' at shape index 'index'. If 101 // is_phi is true, then this value is a phi value, for example, at the 102 // parameter of a while body computation. Phi values are only used in the SSA 103 // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). 104 HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, 105 bool is_phi = false); ~HloValue()106 ~HloValue() override {} 107 108 // Sets the positions in the module at which the HloValue appears. Updates 109 // uses. Should be called once and only once. The defining position should not 110 // be included in 'positions' as this is set at construction time. 111 void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions); 112 113 // Returns whether this value is a phi value. is_phi()114 bool is_phi() const { return is_phi_; } 115 116 // Return the position where this value is defined. defining_position()117 const HloPosition& defining_position() const { return positions_[0]; } 118 119 // Return the instruction which defines this HloValue. defining_instruction()120 HloInstruction* defining_instruction() const { 121 return defining_position().instruction; 122 } 123 instruction()124 HloInstruction* instruction() const override { 125 return defining_instruction(); 126 } 127 128 // Return the shape index at which this HloValue is defined in the output of 129 // its defining instruction. defining_index()130 const ShapeIndex& defining_index() const { return defining_position().index; } 131 index()132 const ShapeIndex& index() const override { return defining_index(); } 133 134 // Return the shape of this HloValue. shape()135 const Shape& shape() const override { return defining_position().shape(); } 136 137 // Return all positions of the HloValue in the module. positions()138 const std::vector<HloPosition>& positions() const { return positions_; } 139 140 // Return all uses of the HloValue. uses()141 const std::vector<HloUse>& uses() const { return uses_; } 142 143 // Get whether this HloValue is live out of the module. live_out_of_module()144 bool live_out_of_module() const { return live_out_of_module_; } 145 146 bool operator==(const HloValue& other) const; 147 bool operator!=(const HloValue& other) const; 148 149 // Return a single-line string representation of the value. 150 string ToShortString() const; 151 152 string ToString(int indent) const; 153 ToString()154 string ToString() const override { return ToString(0); } 155 156 private: 157 // Whether this instruction is a phi value. 158 const bool is_phi_; 159 160 // The set of positions of this HloValue. The first element is always the 161 // position of the definition. 162 std::vector<HloPosition> positions_; 163 164 // The set of uses of this HloValue. 165 std::vector<HloUse> uses_; 166 167 // Whether this value is live out of the HLO module. 168 bool live_out_of_module_ = false; 169 }; 170 171 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); 172 173 // A class representing the possible set of HloValues at a particular point 174 // (shape index in the output of an instruction) in the XLA graph. This set 175 // contains the set of reaching HloValue definitions. For a simple array-shaped 176 // instruction like Add, the HloValueSet of the top-level of the instruction's 177 // output trivially contains only the HloValue defined by the instruction. For 178 // instructions which have non-trivial dataflow such as Tuple or Select, the 179 // HloValueSets of the instruction's output contains one or more HloValues 180 // defined by the instruction's operands or defined further up in the XLA graph. 181 class HloValueSet { 182 public: 183 HloValueSet() = default; 184 HloValueSet(absl::Span<const HloValue * const> values)185 explicit HloValueSet(absl::Span<const HloValue* const> values) 186 : values_(values.begin(), values.end()) { 187 SortAndUniquifyValues(); 188 } 189 190 // Sets this value set to the union of the given value sets. Returns whether 191 // this value set changed. 192 bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs); 193 194 // Return the vector of HloValues in the set. Values in the vector are unique 195 // and stably sorted by value id. values()196 const std::vector<const HloValue*>& values() const { return values_; } 197 198 // Adds the value to the set. Returns true iff the value was added and didn't 199 // already exist in the set. 200 bool AddValue(const HloValue* value); 201 202 // Clear all values from the set. Clear()203 void Clear() { values_.clear(); } 204 205 // Return the unique HLO value in the set. CHECKs if the set does not contain 206 // exactly one value. GetUniqueValue()207 const HloValue& GetUniqueValue() const { 208 CHECK_EQ(values_.size(), 1); 209 return *values_[0]; 210 } 211 212 bool operator==(const HloValueSet& other) const { 213 if (values_.size() != other.values_.size()) return false; 214 for (size_t i = 0; i < values_.size(); ++i) { 215 if (values_[i]->id() != other.values_[i]->id()) { 216 return false; 217 } 218 } 219 return true; 220 } 221 bool operator!=(const HloValueSet& other) const { return !(*this == other); } 222 223 string ToString() const; 224 225 private: 226 // Sorts value_ and removes duplicates. This should be called after adding any 227 // elements to values_. 228 void SortAndUniquifyValues(); 229 230 // HloValues sorted by HloValue::Id. 231 std::vector<const HloValue*> values_; 232 }; 233 234 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); 235 236 // A class collecting the HloValues which might be contained in the output of 237 // an HLO instruction. For array-shaped instructions, an InstructionValueSet 238 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets 239 // hold multiple HloValueSets. 240 class InstructionValueSet : public ShapeTree<HloValueSet> { 241 public: InstructionValueSet(const Shape & shape)242 InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {} 243 244 // Sets this value set to the union of the given value sets. Returns whether 245 // this value set changed. 246 bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs); 247 248 string ToString() const; 249 }; 250 251 std::ostream& operator<<(std::ostream& out, 252 const InstructionValueSet& instruction_value_set); 253 254 } // namespace xla 255 256 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 257