1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 18 19 #include <cstdio> 20 #include <list> 21 #include <vector> 22 23 #include "absl/base/casts.h" 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/map_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // A class for representing reachability between HloInstructions. 37 // 38 // It has an adjacency matrix and it is up to the user of the class to set the 39 // adjacency matrix such that it represents reachability, i.e. such that it is 40 // transitive. That the graph be transitive is thus not an invariant of this 41 // class, but it is required for the name of the class and its methods to make 42 // sense. 43 class HloReachabilityMap { 44 public: 45 // Sets up a graph with no edges and where the nodes correspond to the given 46 // instructions. 47 explicit HloReachabilityMap( 48 absl::Span<const HloInstruction* const> instructions); 49 50 // Computes and returns the reachability between HLO instructions in the 51 // computation. The returned HloReachabilityMap is constructed such that 52 // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a 53 // directed path (from producer to consumer) from 'a' to 'b'. Both data 54 // dependencies (operands) and control dependencies are considered for 55 // reachability. Trivially an instruction is reachable from itself. 56 static std::unique_ptr<HloReachabilityMap> Build( 57 const HloComputation* computation); 58 59 // Set the reachability set of 'instruction' to the union of the reachability 60 // sets of 'inputs'. Upon return, IsReachable(x, instruction) where 61 // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true 62 // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from 63 // itself. Returns whether the reachability set of 'instruction' changed. 64 // 65 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 66 // vector in the internal graph of this HloReachabilityMap for the given 67 // instruction and does not transitively update any other part of the 68 // adjacency matrix. 69 bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs, 70 const HloInstruction* instruction); 71 72 // As above, but faster because it does not check if the reachability changed. 73 void FastSetReachabilityToUnion( 74 absl::Span<const HloInstruction* const> inputs, 75 const HloInstruction* instruction); 76 77 // Sets entry so that IsReachable(a, b) will return true 78 // 79 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 80 // matrix in the internal graph of this HloReachabilityMap to have an edge 81 // from a to b and does not transitively update any other part of the 82 // adjacency matrix. 83 void SetReachable(const HloInstruction* a, const HloInstruction* b); 84 85 // Updates the given reachability map after the immediate predecessor set 86 // (operands and control predecessors) of 'instruction' has changed. 87 void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); 88 89 // Returns true if "b" is reachable from "a" 90 // 91 // Note that this function only correctly answers queries about reachability 92 // if the set of edges that have been provided to this class are transitive. 93 bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; 94 95 // Returns true if "b" is reachable from "a" or "a" is reachable from "b" 96 // 97 // Note that this function only correctly answers queries about reachability 98 // if the set of edges that have been provided to this class are transitive. 99 bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; 100 101 // Checks if an instruction is in the Reachability map. IsPresent(const HloInstruction * a)102 bool IsPresent(const HloInstruction* a) const { 103 return indices_.contains(GetKey(a)); 104 } 105 106 private: 107 // A bit-vector implementation specialized for this use case which provides a 108 // fast bitwise OR operation not available in tensorflow::gtl::BitMap. 109 class BitVector { 110 public: 111 BitVector() = default; BitVector(size_t size)112 BitVector(size_t size) 113 : size_(size), vector_((size + kBits - 1) / kBits, 0) {} 114 115 // Return the bit at the given index. Get(size_t index)116 bool Get(size_t index) const { 117 DCHECK(index >= 0 && index < size_); 118 return vector_[index / kBits] & (1ull << (index % kBits)); 119 } 120 121 // Set the bit at the given index. Set(size_t index)122 void Set(size_t index) { 123 DCHECK(index >= 0 && index < size_); 124 vector_[index / kBits] |= 1ull << (index % kBits); 125 } 126 127 // Set this bitvector to the Logical OR of this bitvector and 'other'. OrWith(const BitVector & other)128 void OrWith(const BitVector& other) { 129 for (size_t i = 0; i < vector_.size(); ++i) { 130 vector_[i] |= other.vector_[i]; 131 } 132 } 133 134 // Set the bitvector to all zeros. SetToZero()135 void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } 136 137 bool operator==(const BitVector& other) const { 138 return vector_ == other.vector_; 139 } 140 bool operator!=(const BitVector& other) const { 141 return vector_ != other.vector_; 142 } 143 144 private: 145 using Word = uint64; 146 static const size_t kBits = 64; 147 148 // Number of bits in the bitvector. 149 size_t size_; 150 151 std::vector<Word> vector_; 152 }; 153 154 // Return the bitvector storing the reachability-to of the given instruction. GetBitVector(const HloInstruction * instruction)155 const BitVector& GetBitVector(const HloInstruction* instruction) const { 156 return bit_vectors_[GetIndex(instruction)]; 157 } GetBitVector(const HloInstruction * instruction)158 BitVector& GetBitVector(const HloInstruction* instruction) { 159 return bit_vectors_[GetIndex(instruction)]; 160 } 161 162 // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. 163 void SetReachabilityToUnionHelper( 164 absl::Span<const HloInstruction* const> inputs, 165 const HloInstruction* instruction, BitVector* bit_vector); 166 GetKey(const HloInstruction * instruction)167 uint64 GetKey(const HloInstruction* instruction) const { 168 uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id()); 169 uint64 module_id = 170 absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id()); 171 return (module_id << 32) | unique_id; 172 } 173 // Return the index of the given instruction. The value is used to index into 174 // the vector of BitVectors and the BitVectors themselves. GetIndex(const HloInstruction * instruction)175 int GetIndex(const HloInstruction* instruction) const { 176 return FindOrDie(indices_, GetKey(instruction)); 177 } 178 179 // The number of instructions in the reachability map. 180 const size_t size_; 181 182 // Dense assignment from HloInstruction::unique_id to number. These numbers 183 // index into the bit_vectors_ vector and into the bits within a BitVector. 184 absl::flat_hash_map<uint64, int> indices_; 185 186 // Bitvectors holding the reachability to each instruction. The bit vector for 187 // instruction X includes ones for each instruction which X is reachable from. 188 std::vector<BitVector> bit_vectors_; 189 190 // A temporary used by SetReachabilityToUnion to avoid an allocation with each 191 // call to the method. 192 BitVector tmp_bit_vector_; 193 }; 194 195 } // namespace xla 196 197 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 198