1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 18 19 #include <cstdio> 20 #include <list> 21 #include <vector> 22 23 #include "absl/base/casts.h" 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/map_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // A class for representing reachability between HloInstructions. 37 // 38 // It has an adjacency matrix and it is up to the user of the class to set the 39 // adjacency matrix such that it represents reachability, i.e. such that it is 40 // transitive. That the graph be transitive is thus not an invariant of this 41 // class, but it is required for the name of the class and its methods to make 42 // sense. 43 class HloReachabilityMap { 44 public: 45 // Sets up a graph with no edges and where the nodes correspond to the given 46 // instructions. 47 explicit HloReachabilityMap( 48 absl::Span<const HloInstruction* const> instructions); 49 50 // Computes and returns the reachability between HLO instructions in the 51 // computation. The returned HloReachabilityMap is constructed such that 52 // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a 53 // directed path (from producer to consumer) from 'a' to 'b'. Both data 54 // dependencies (operands) and control dependencies are considered for 55 // reachability. Trivially an instruction is reachable from itself. 56 static std::unique_ptr<HloReachabilityMap> Build( 57 const HloComputation* computation); 58 59 // Similar to the above Build operation except that it tries to identify 60 // paths between instructions that do not contain control instructions 61 // and multiple operands, i.e., b is_reachable a == true iff 62 // b = f(f(f(f(f(a), constant), constant), constant). 63 // Further, the only ops allowed in a path are basic math operations such 64 // as add, sub, mul, div. 65 static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions( 66 const HloComputation* computation, 67 absl::FunctionRef<void(const HloInstruction*, 68 std::vector<HloInstruction*>*)> 69 add_dependencies); 70 71 // Set the reachability set of 'instruction' to the union of the reachability 72 // sets of 'inputs'. Upon return, IsReachable(x, instruction) where 73 // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true 74 // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from 75 // itself. Returns whether the reachability set of 'instruction' changed. 76 // 77 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 78 // vector in the internal graph of this HloReachabilityMap for the given 79 // instruction and does not transitively update any other part of the 80 // adjacency matrix. 81 bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs, 82 const HloInstruction* instruction); 83 84 // As above, but faster because it does not check if the reachability changed. 85 void FastSetReachabilityToUnion( 86 absl::Span<const HloInstruction* const> inputs, 87 const HloInstruction* instruction); 88 89 // An opaque index that clients can use to make repeated operations for the 90 // same instruction faster, by calling GetIndex once for the instruction, 91 // and then calling the variants of other interfaces that take Index arguments 92 // rather than HloInstruction* arguments. 93 struct Index { 94 private: 95 friend class HloReachabilityMap; 96 97 // Index assigned for a particular instruction. The value is used to index 98 // into the vector of BitVectors and the BitVectors themselves. 99 int v; 100 }; GetIndex(const HloInstruction * instruction)101 Index GetIndex(const HloInstruction* instruction) const { 102 Index i; 103 i.v = FindOrDie(indices_, GetKey(instruction)); 104 return i; 105 } 106 107 // Sets entry so that IsReachable(a, b) will return true 108 // 109 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 110 // matrix in the internal graph of this HloReachabilityMap to have an edge 111 // from a to b and does not transitively update any other part of the 112 // adjacency matrix. SetReachable(const HloInstruction * a,const HloInstruction * b)113 void SetReachable(const HloInstruction* a, const HloInstruction* b) { 114 SetReachable(GetIndex(a), GetIndex(b)); 115 } 116 void SetReachable(Index a, Index b); 117 118 // Updates the given reachability map after the immediate predecessor set 119 // (operands and control predecessors) of 'instruction' has changed. 120 void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); 121 122 // Returns true if "b" is reachable from "a" 123 // 124 // Note that this function only correctly answers queries about reachability 125 // if the set of edges that have been provided to this class are transitive. IsReachable(const HloInstruction * a,const HloInstruction * b)126 bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { 127 return IsReachable(GetIndex(a), GetIndex(b)); 128 } IsReachable(Index a,Index b)129 bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); } 130 131 // Returns true if "b" is reachable from "a" or "a" is reachable from "b" 132 // 133 // Note that this function only correctly answers queries about reachability 134 // if the set of edges that have been provided to this class are transitive. IsConnected(const HloInstruction * a,const HloInstruction * b)135 bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { 136 return IsConnected(GetIndex(a), GetIndex(b)); 137 } IsConnected(Index a,Index b)138 bool IsConnected(Index a, Index b) const { 139 return IsReachable(a, b) || IsReachable(b, a); 140 } 141 142 // Checks if an instruction is in the Reachability map. IsPresent(const HloInstruction * a)143 bool IsPresent(const HloInstruction* a) const { 144 return indices_.contains(GetKey(a)); 145 } 146 147 // Replace the instruction "original" with "replacement" in the reachability 148 // map. 149 void Replace(const HloInstruction* original, 150 const HloInstruction* replacement); 151 152 private: 153 // A bit-vector implementation specialized for this use case which provides a 154 // fast bitwise OR operation not available in tensorflow::gtl::BitMap. 155 class BitVector { 156 public: 157 BitVector() = default; BitVector(size_t size)158 BitVector(size_t size) 159 : size_(size), vector_((size + kBits - 1) / kBits, 0) {} 160 161 // Return the bit at the given index. Get(size_t index)162 bool Get(size_t index) const { 163 DCHECK(index >= 0 && index < size_); 164 return vector_[index / kBits] & (1ull << (index % kBits)); 165 } 166 167 // Set the bit at the given index. Set(size_t index)168 void Set(size_t index) { 169 DCHECK(index >= 0 && index < size_); 170 vector_[index / kBits] |= 1ull << (index % kBits); 171 } 172 173 // Set this bitvector to the Logical OR of this bitvector and 'other'. OrWith(const BitVector & other)174 void OrWith(const BitVector& other) { 175 for (size_t i = 0; i < vector_.size(); ++i) { 176 vector_[i] |= other.vector_[i]; 177 } 178 } 179 180 // Set the bitvector to all zeros. SetToZero()181 void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } 182 183 bool operator==(const BitVector& other) const { 184 return vector_ == other.vector_; 185 } 186 bool operator!=(const BitVector& other) const { 187 return vector_ != other.vector_; 188 } 189 190 private: 191 using Word = uint64; 192 static constexpr size_t kBits = 64; 193 194 // Number of bits in the bitvector. 195 size_t size_; 196 197 std::vector<Word> vector_; 198 }; 199 200 // Return the bitvector storing the reachability-to of the given instruction. GetBitVector(const HloInstruction * instruction)201 const BitVector& GetBitVector(const HloInstruction* instruction) const { 202 return GetBitVector(GetIndex(instruction)); 203 } GetBitVector(const HloInstruction * instruction)204 BitVector& GetBitVector(const HloInstruction* instruction) { 205 return GetBitVector(GetIndex(instruction)); 206 } 207 GetBitVector(Index index)208 const BitVector& GetBitVector(Index index) const { 209 return bit_vectors_[index.v]; 210 } GetBitVector(Index index)211 BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; } 212 213 // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. 214 void SetReachabilityToUnionHelper( 215 absl::Span<const HloInstruction* const> inputs, 216 const HloInstruction* instruction, BitVector* bit_vector); 217 GetKey(const HloInstruction * instruction)218 uint64 GetKey(const HloInstruction* instruction) const { 219 uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id()); 220 uint64 module_id = 221 absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id()); 222 return (module_id << 32) | unique_id; 223 } 224 // Return the index of the given instruction. GetIndexInternal(const HloInstruction * instruction)225 int GetIndexInternal(const HloInstruction* instruction) const { 226 return FindOrDie(indices_, GetKey(instruction)); 227 } 228 229 // The number of instructions in the reachability map. 230 const size_t size_; 231 232 // Dense assignment from HloInstruction::unique_id to number. These numbers 233 // index into the bit_vectors_ vector and into the bits within a BitVector. 234 absl::flat_hash_map<uint64, int> indices_; 235 236 // Bitvectors holding the reachability to each instruction. The bit vector for 237 // instruction X includes ones for each instruction which X is reachable from. 238 std::vector<BitVector> bit_vectors_; 239 240 // A temporary used by SetReachabilityToUnion to avoid an allocation with each 241 // call to the method. 242 BitVector tmp_bit_vector_; 243 }; 244 245 } // namespace xla 246 247 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 248