1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 18 19 #include <cstdio> 20 #include <list> 21 #include <vector> 22 23 #include "absl/base/casts.h" 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/map_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // A class for representing reachability between HloInstructions. 37 // 38 // It has an adjacency matrix and it is up to the user of the class to set the 39 // adjacency matrix such that it represents reachability, i.e. such that it is 40 // transitive. That the graph be transitive is thus not an invariant of this 41 // class, but it is required for the name of the class and its methods to make 42 // sense. 43 class HloReachabilityMap { 44 public: 45 // An opaque index that clients can use to make repeated operations for the 46 // same instruction faster, by calling GetIndex once for the instruction, 47 // and then calling the variants of other interfaces that take Index arguments 48 // rather than HloInstruction* arguments. 49 struct Index { 50 public: 51 bool operator==(Index other) const { return v == other.v; } 52 bool operator!=(Index other) const { return v != other.v; } 53 54 private: 55 friend class HloReachabilityMap; 56 57 // Index assigned for a particular instruction. The value is used to index 58 // into the vector of BitVectors and the BitVectors themselves. 59 int v; 60 }; 61 // Sets up a graph with no edges and where the nodes correspond to the given 62 // instructions. 63 explicit HloReachabilityMap( 64 absl::Span<const HloInstruction* const> instructions); 65 66 // Computes and returns the reachability between HLO instructions in the 67 // computation. The returned HloReachabilityMap is constructed such that 68 // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a 69 // directed path (from producer to consumer) from 'a' to 'b'. Both data 70 // dependencies (operands) and control dependencies are considered for 71 // reachability. Trivially an instruction is reachable from itself. 72 static std::unique_ptr<HloReachabilityMap> Build( 73 const HloComputation* computation); 74 75 // Similar to the above Build operation except that it tries to identify 76 // paths between instructions that do not contain control instructions 77 // and multiple operands, i.e., b is_reachable a == true iff 78 // b = f(f(f(f(f(a), constant), constant), constant). 79 // Further, the only ops allowed in a path are basic math operations such 80 // as add, sub, mul, div. 81 static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions( 82 const HloComputation* computation, 83 absl::FunctionRef<void(const HloInstruction*, 84 std::vector<HloInstruction*>*)> 85 add_dependencies); 86 87 // Set the reachability set of 'instruction' to the union of the reachability 88 // sets of 'inputs'. Upon return, IsReachable(x, instruction) where 89 // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true 90 // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from 91 // itself. Returns whether the reachability set of 'instruction' changed. 92 // 93 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 94 // vector in the internal graph of this HloReachabilityMap for the given 95 // instruction and does not transitively update any other part of the 96 // adjacency matrix. 97 bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs, 98 const HloInstruction* instruction); 99 100 // As above, but faster because it does not check if the reachability changed. 101 void FastSetReachabilityToUnion( 102 absl::Span<const HloInstruction* const> inputs, 103 const HloInstruction* instruction); 104 // As above, but use Index instead if it's already looked up which is even 105 // faster since no hash map lookup will occur. 106 void FastSetReachabilityToUnion(absl::Span<const Index> input_indices, 107 Index index); 108 GetIndex(const HloInstruction * instruction)109 Index GetIndex(const HloInstruction* instruction) const { 110 Index i; 111 i.v = FindOrDie(indices_, GetKey(instruction)); 112 return i; 113 } 114 115 // Sets entry so that IsReachable(a, b) will return true 116 // 117 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 118 // matrix in the internal graph of this HloReachabilityMap to have an edge 119 // from a to b and does not transitively update any other part of the 120 // adjacency matrix. SetReachable(const HloInstruction * a,const HloInstruction * b)121 void SetReachable(const HloInstruction* a, const HloInstruction* b) { 122 SetReachable(GetIndex(a), GetIndex(b)); 123 } 124 void SetReachable(Index a, Index b); 125 126 // Updates the given reachability map after the immediate predecessor set 127 // (operands and control predecessors) of 'instruction' has changed. 128 void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); 129 130 // Returns true if "b" is reachable from "a" 131 // 132 // Note that this function only correctly answers queries about reachability 133 // if the set of edges that have been provided to this class are transitive. IsReachable(const HloInstruction * a,const HloInstruction * b)134 bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { 135 return IsReachable(GetIndex(a), GetIndex(b)); 136 } IsReachable(Index a,Index b)137 bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); } 138 139 // Returns true if "b" is reachable from "a" or "a" is reachable from "b" 140 // 141 // Note that this function only correctly answers queries about reachability 142 // if the set of edges that have been provided to this class are transitive. IsConnected(const HloInstruction * a,const HloInstruction * b)143 bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { 144 return IsConnected(GetIndex(a), GetIndex(b)); 145 } IsConnected(Index a,Index b)146 bool IsConnected(Index a, Index b) const { 147 return IsReachable(a, b) || IsReachable(b, a); 148 } 149 150 // Checks if an instruction is in the Reachability map. IsPresent(const HloInstruction * a)151 bool IsPresent(const HloInstruction* a) const { 152 return indices_.contains(GetKey(a)); 153 } 154 155 // Replace the instruction "original" with "replacement" in the reachability 156 // map. 157 void Replace(const HloInstruction* original, 158 const HloInstruction* replacement); 159 160 private: 161 // A bit-vector implementation specialized for this use case which provides a 162 // fast bitwise OR operation not available in tensorflow::gtl::BitMap. 163 class BitVector { 164 public: 165 BitVector() = default; BitVector(size_t size)166 BitVector(size_t size) 167 : size_(size), vector_((size + kBits - 1) / kBits, 0) {} 168 169 // Return the bit at the given index. Get(size_t index)170 bool Get(size_t index) const { 171 DCHECK(index >= 0 && index < size_); 172 return vector_[index / kBits] & (1ull << (index % kBits)); 173 } 174 175 // Set the bit at the given index. Set(size_t index)176 void Set(size_t index) { 177 DCHECK(index >= 0 && index < size_); 178 vector_[index / kBits] |= 1ull << (index % kBits); 179 } 180 181 // Set this bitvector to the Logical OR of this bitvector and 'other'. OrWith(const BitVector & other)182 void OrWith(const BitVector& other) { 183 for (size_t i = 0; i < vector_.size(); ++i) { 184 vector_[i] |= other.vector_[i]; 185 } 186 } 187 188 // Set the bitvector to all zeros. SetToZero()189 void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } 190 191 bool operator==(const BitVector& other) const { 192 return vector_ == other.vector_; 193 } 194 bool operator!=(const BitVector& other) const { 195 return vector_ != other.vector_; 196 } 197 198 private: 199 using Word = uint64; 200 static constexpr size_t kBits = 64; 201 202 // Number of bits in the bitvector. 203 size_t size_; 204 205 std::vector<Word> vector_; 206 }; 207 208 // Return the bitvector storing the reachability-to of the given instruction. GetBitVector(const HloInstruction * instruction)209 const BitVector& GetBitVector(const HloInstruction* instruction) const { 210 return GetBitVector(GetIndex(instruction)); 211 } GetBitVector(const HloInstruction * instruction)212 BitVector& GetBitVector(const HloInstruction* instruction) { 213 return GetBitVector(GetIndex(instruction)); 214 } 215 GetBitVector(Index index)216 const BitVector& GetBitVector(Index index) const { 217 return bit_vectors_[index.v]; 218 } GetBitVector(Index index)219 BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; } 220 221 // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. 222 void SetReachabilityToUnionHelper( 223 absl::Span<const HloInstruction* const> inputs, Index index); 224 void SetReachabilityToUnionHelper(absl::Span<const Index> input_indices, 225 Index index); 226 GetKey(const HloInstruction * instruction)227 uint64 GetKey(const HloInstruction* instruction) const { 228 uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id()); 229 uint64 module_id = 230 absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id()); 231 return (module_id << 32) | unique_id; 232 } 233 // Return the index of the given instruction. GetIndexInternal(const HloInstruction * instruction)234 int GetIndexInternal(const HloInstruction* instruction) const { 235 return FindOrDie(indices_, GetKey(instruction)); 236 } 237 238 // The number of instructions in the reachability map. 239 const size_t size_; 240 241 // Dense assignment from HloInstruction::unique_id to number. These numbers 242 // index into the bit_vectors_ vector and into the bits within a BitVector. 243 absl::flat_hash_map<uint64, int> indices_; 244 245 // Bitvectors holding the reachability to each instruction. The bit vector for 246 // instruction X includes ones for each instruction which X is reachable from. 247 std::vector<BitVector> bit_vectors_; 248 249 // A temporary used by SetReachabilityToUnion to avoid an allocation with each 250 // call to the method. 251 BitVector tmp_bit_vector_; 252 }; 253 254 } // namespace xla 255 256 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 257