1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 18 19 #include <stddef.h> 20 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/buffer_value.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/shape_tree.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace xla { 36 37 // Abstraction which identifies a specific point in the XLA graph. An 38 // HloPosition specifies a ShapeIndex within the output of a specific 39 // instruction. 40 struct HloPosition { 41 HloInstruction* instruction; 42 ShapeIndex index; 43 44 // Returns the shape at this position. 45 const Shape& shape() const; 46 47 string ToString() const; 48 49 bool operator==(const HloPosition& other) const { 50 return instruction == other.instruction && index == other.index; 51 } 52 bool operator!=(const HloPosition& other) const { return !(*this == other); } 53 54 // Stable less-than operator using instruction id and index. 55 bool operator<(const HloPosition& other) const { 56 return instruction->unique_id() < other.instruction->unique_id() || 57 (instruction->unique_id() == other.instruction->unique_id() && 58 index < other.index); 59 } 60 }; 61 62 std::ostream& operator<<(std::ostream& out, const HloPosition& position); 63 64 // Defines a single use of an HLO value. 65 struct HloUse { 66 // Instruction at which the value is used. 67 HloInstruction* instruction; 68 69 // The operand number in which the value is appears. 70 int64 operand_number; 71 72 // The shape index within the operand in which the value appears. 73 ShapeIndex operand_index; 74 75 string ToString() const; 76 77 bool operator==(const HloUse& other) const { 78 return instruction == other.instruction && 79 operand_number == other.operand_number && 80 operand_index == other.operand_index; 81 } 82 83 bool operator!=(const HloUse& other) const { return !(*this == other); } 84 }; 85 86 std::ostream& operator<<(std::ostream& out, const HloUse& use); 87 88 // HloDataflowAnalysis uses this subclass of BufferValue. 89 class HloValue : public BufferValue { 90 public: 91 // Predicate comparing HloValues by increasing id, useful for std::sort. IdLessThan(const HloValue * a,const HloValue * b)92 static bool IdLessThan(const HloValue* a, const HloValue* b) { 93 return a->id() < b->id(); 94 } 95 96 // Predicate comparing HloValues by equal id, useful for std::unique. IdEqual(const HloValue * a,const HloValue * b)97 static bool IdEqual(const HloValue* a, const HloValue* b) { 98 return a->id() == b->id(); 99 } 100 101 // Construct an HloValue defined by 'instruction' at shape index 'index'. If 102 // is_phi is true, then this value is a phi value, for example, at the 103 // parameter of a while body computation. Phi values are only used in the SSA 104 // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). 105 HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, 106 bool is_phi = false); ~HloValue()107 ~HloValue() override {} 108 109 // Sets the positions in the module at which the HloValue appears. Updates 110 // uses. Should be called once and only once. The defining position should not 111 // be included in 'positions' as this is set at construction time. 112 void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions); 113 114 // Returns whether this value is a phi value. is_phi()115 bool is_phi() const { return is_phi_; } 116 117 // Return the position where this value is defined. defining_position()118 const HloPosition& defining_position() const { return positions_[0]; } 119 120 // Return the instruction which defines this HloValue. defining_instruction()121 HloInstruction* defining_instruction() const { 122 return defining_position().instruction; 123 } 124 instruction()125 HloInstruction* instruction() const override { 126 return defining_instruction(); 127 } 128 129 // Return the shape index at which this HloValue is defined in the output of 130 // its defining instruction. defining_index()131 const ShapeIndex& defining_index() const { return defining_position().index; } 132 index()133 const ShapeIndex& index() const override { return defining_index(); } 134 135 // Return the shape of this HloValue. shape()136 const Shape& shape() const override { return defining_position().shape(); } 137 138 // Return all positions of the HloValue in the module. positions()139 const std::vector<HloPosition>& positions() const { return positions_; } 140 141 // Return all uses of the HloValue. uses()142 const std::vector<HloUse>& uses() const { return uses_; } 143 144 // Get whether this HloValue is live out of the module. live_out_of_module()145 bool live_out_of_module() const { return live_out_of_module_; } 146 147 bool operator==(const HloValue& other) const; 148 bool operator!=(const HloValue& other) const; 149 150 // Return a single-line string representation of the value. 151 string ToShortString() const; 152 153 string ToString(int indent) const; 154 ToString()155 string ToString() const override { return ToString(0); } 156 157 private: 158 // Whether this instruction is a phi value. 159 const bool is_phi_; 160 161 // The set of positions of this HloValue. The first element is always the 162 // position of the definition. 163 std::vector<HloPosition> positions_; 164 165 // The set of uses of this HloValue. 166 std::vector<HloUse> uses_; 167 168 // Whether this value is live out of the HLO module. 169 bool live_out_of_module_ = false; 170 }; 171 172 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); 173 174 // A class representing the possible set of HloValues at a particular point 175 // (shape index in the output of an instruction) in the XLA graph. This set 176 // contains the set of reaching HloValue definitions. For a simple array-shaped 177 // instruction like Add, the HloValueSet of the top-level of the instruction's 178 // output trivially contains only the HloValue defined by the instruction. For 179 // instructions which have non-trivial dataflow such as Tuple or Select, the 180 // HloValueSets of the instruction's output contains one or more HloValues 181 // defined by the instruction's operands or defined further up in the XLA graph. 182 class HloValueSet { 183 public: 184 HloValueSet() = default; 185 HloValueSet(absl::Span<const HloValue * const> values)186 explicit HloValueSet(absl::Span<const HloValue* const> values) 187 : values_(values.begin(), values.end()) { 188 SortAndUniquifyValues(); 189 } 190 191 // Sets this value set to the union of the given value sets. Returns whether 192 // this value set changed. 193 bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs); 194 195 // Return the vector of HloValues in the set. Values in the vector are unique 196 // and stably sorted by value id. values()197 const std::vector<const HloValue*>& values() const { return values_; } 198 199 // Adds the value to the set. Returns true iff the value was added and didn't 200 // already exist in the set. 201 bool AddValue(const HloValue* value); 202 203 // Clear all values from the set. Clear()204 void Clear() { values_.clear(); } 205 206 // Return the unique HLO value in the set. CHECKs if the set does not contain 207 // exactly one value. GetUniqueValue()208 const HloValue& GetUniqueValue() const { 209 CHECK_EQ(values_.size(), 1); 210 return *values_[0]; 211 } 212 213 bool operator==(const HloValueSet& other) const { 214 if (values_.size() != other.values_.size()) return false; 215 for (size_t i = 0; i < values_.size(); ++i) { 216 if (values_[i]->id() != other.values_[i]->id()) { 217 return false; 218 } 219 } 220 return true; 221 } 222 bool operator!=(const HloValueSet& other) const { return !(*this == other); } 223 224 string ToString() const; 225 226 private: 227 // Sorts value_ and removes duplicates. This should be called after adding any 228 // elements to values_. 229 void SortAndUniquifyValues(); 230 231 // HloValues sorted by HloValue::Id. 232 std::vector<const HloValue*> values_; 233 }; 234 235 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); 236 237 // A class collecting the HloValues which might be contained in the output of 238 // an HLO instruction. For array-shaped instructions, an InstructionValueSet 239 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets 240 // hold multiple HloValueSets. 241 class InstructionValueSet : public ShapeTree<HloValueSet> { 242 public: InstructionValueSet(const Shape & shape)243 InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {} 244 245 // Sets this value set to the union of the given value sets. Returns whether 246 // this value set changed. 247 bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs); 248 249 // Returns true if any value sets for any subshape element is not a 250 // singleton. 251 bool IsAmbiguous() const; 252 253 string ToString() const; 254 }; 255 256 std::ostream& operator<<(std::ostream& out, 257 const InstructionValueSet& instruction_value_set); 258 259 } // namespace xla 260 261 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ 262