• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
18 
#include <algorithm>
#include <cstdio>
#include <list>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // A class for representing reachability between HloInstructions.
37 //
38 // It has an adjacency matrix and it is up to the user of the class to set the
39 // adjacency matrix such that it represents reachability, i.e. such that it is
40 // transitive. That the graph be transitive is thus not an invariant of this
41 // class, but it is required for the name of the class and its methods to make
42 // sense.
43 class HloReachabilityMap {
44  public:
45   // An opaque index that clients can use to make repeated operations for the
46   // same instruction faster, by calling GetIndex once for the instruction,
47   // and then calling the variants of other interfaces that take Index arguments
48   // rather than HloInstruction* arguments.
49   struct Index {
50    public:
51     bool operator==(Index other) const { return v == other.v; }
52     bool operator!=(Index other) const { return v != other.v; }
53 
54    private:
55     friend class HloReachabilityMap;
56 
57     // Index assigned for a particular instruction.  The value is used to index
58     // into the vector of BitVectors and the BitVectors themselves.
59     int v;
60   };
61   // Sets up a graph with no edges and where the nodes correspond to the given
62   // instructions.
63   explicit HloReachabilityMap(
64       absl::Span<const HloInstruction* const> instructions);
65 
66   // Computes and returns the reachability between HLO instructions in the
67   // computation. The returned HloReachabilityMap is constructed such that
68   // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
69   // directed path (from producer to consumer) from 'a' to 'b'. Both data
70   // dependencies (operands) and control dependencies are considered for
71   // reachability. Trivially an instruction is reachable from itself.
72   static std::unique_ptr<HloReachabilityMap> Build(
73       const HloComputation* computation);
74 
75   // Similar to the above Build operation except that it tries to identify
76   // paths between instructions that do not contain control instructions
77   // and multiple operands, i.e., b is_reachable a == true iff
78   // b = f(f(f(f(f(a), constant), constant), constant).
79   // Further, the only ops allowed in a path are basic math operations such
80   // as add, sub, mul, div.
81   static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions(
82       const HloComputation* computation,
83       absl::FunctionRef<void(const HloInstruction*,
84                              std::vector<HloInstruction*>*)>
85           add_dependencies);
86 
87   // Set the reachability set of 'instruction' to the union of the reachability
88   // sets of 'inputs'. Upon return, IsReachable(x, instruction) where
89   // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true
90   // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from
91   // itself. Returns whether the reachability set of 'instruction' changed.
92   //
93   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
94   // vector in the internal graph of this HloReachabilityMap for the given
95   // instruction and does not transitively update any other part of the
96   // adjacency matrix.
97   bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
98                               const HloInstruction* instruction);
99 
100   // As above, but faster because it does not check if the reachability changed.
101   void FastSetReachabilityToUnion(
102       absl::Span<const HloInstruction* const> inputs,
103       const HloInstruction* instruction);
104   // As above, but use Index instead if it's already looked up which is even
105   // faster since no hash map lookup will occur.
106   void FastSetReachabilityToUnion(absl::Span<const Index> input_indices,
107                                   Index index);
108 
GetIndex(const HloInstruction * instruction)109   Index GetIndex(const HloInstruction* instruction) const {
110     Index i;
111     i.v = FindOrDie(indices_, GetKey(instruction));
112     return i;
113   }
114 
115   // Sets entry so that IsReachable(a, b) will return true
116   //
117   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
118   // matrix in the internal graph of this HloReachabilityMap to have an edge
119   // from a to b and does not transitively update any other part of the
120   // adjacency matrix.
SetReachable(const HloInstruction * a,const HloInstruction * b)121   void SetReachable(const HloInstruction* a, const HloInstruction* b) {
122     SetReachable(GetIndex(a), GetIndex(b));
123   }
124   void SetReachable(Index a, Index b);
125 
126   // Updates the given reachability map after the immediate predecessor set
127   // (operands and control predecessors) of 'instruction' has changed.
128   void UpdateReachabilityThroughInstruction(const HloInstruction* instruction);
129 
130   // Returns true if "b" is reachable from "a"
131   //
132   // Note that this function only correctly answers queries about reachability
133   // if the set of edges that have been provided to this class are transitive.
IsReachable(const HloInstruction * a,const HloInstruction * b)134   bool IsReachable(const HloInstruction* a, const HloInstruction* b) const {
135     return IsReachable(GetIndex(a), GetIndex(b));
136   }
IsReachable(Index a,Index b)137   bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); }
138 
139   // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
140   //
141   // Note that this function only correctly answers queries about reachability
142   // if the set of edges that have been provided to this class are transitive.
IsConnected(const HloInstruction * a,const HloInstruction * b)143   bool IsConnected(const HloInstruction* a, const HloInstruction* b) const {
144     return IsConnected(GetIndex(a), GetIndex(b));
145   }
IsConnected(Index a,Index b)146   bool IsConnected(Index a, Index b) const {
147     return IsReachable(a, b) || IsReachable(b, a);
148   }
149 
150   // Checks if an instruction is in the Reachability map.
IsPresent(const HloInstruction * a)151   bool IsPresent(const HloInstruction* a) const {
152     return indices_.contains(GetKey(a));
153   }
154 
155   // Replace the instruction "original" with "replacement" in the reachability
156   // map.
157   void Replace(const HloInstruction* original,
158                const HloInstruction* replacement);
159 
160  private:
161   // A bit-vector implementation specialized for this use case which provides a
162   // fast bitwise OR operation not available in tensorflow::gtl::BitMap.
163   class BitVector {
164    public:
165     BitVector() = default;
BitVector(size_t size)166     BitVector(size_t size)
167         : size_(size), vector_((size + kBits - 1) / kBits, 0) {}
168 
169     // Return the bit at the given index.
Get(size_t index)170     bool Get(size_t index) const {
171       DCHECK(index >= 0 && index < size_);
172       return vector_[index / kBits] & (1ull << (index % kBits));
173     }
174 
175     // Set the bit at the given index.
Set(size_t index)176     void Set(size_t index) {
177       DCHECK(index >= 0 && index < size_);
178       vector_[index / kBits] |= 1ull << (index % kBits);
179     }
180 
181     // Set this bitvector to the Logical OR of this bitvector and 'other'.
OrWith(const BitVector & other)182     void OrWith(const BitVector& other) {
183       for (size_t i = 0; i < vector_.size(); ++i) {
184         vector_[i] |= other.vector_[i];
185       }
186     }
187 
188     // Set the bitvector to all zeros.
SetToZero()189     void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); }
190 
191     bool operator==(const BitVector& other) const {
192       return vector_ == other.vector_;
193     }
194     bool operator!=(const BitVector& other) const {
195       return vector_ != other.vector_;
196     }
197 
198    private:
199     using Word = uint64;
200     static constexpr size_t kBits = 64;
201 
202     // Number of bits in the bitvector.
203     size_t size_;
204 
205     std::vector<Word> vector_;
206   };
207 
208   // Return the bitvector storing the reachability-to of the given instruction.
GetBitVector(const HloInstruction * instruction)209   const BitVector& GetBitVector(const HloInstruction* instruction) const {
210     return GetBitVector(GetIndex(instruction));
211   }
GetBitVector(const HloInstruction * instruction)212   BitVector& GetBitVector(const HloInstruction* instruction) {
213     return GetBitVector(GetIndex(instruction));
214   }
215 
GetBitVector(Index index)216   const BitVector& GetBitVector(Index index) const {
217     return bit_vectors_[index.v];
218   }
GetBitVector(Index index)219   BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; }
220 
221   // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
222   void SetReachabilityToUnionHelper(
223       absl::Span<const HloInstruction* const> inputs, Index index);
224   void SetReachabilityToUnionHelper(absl::Span<const Index> input_indices,
225                                     Index index);
226 
GetKey(const HloInstruction * instruction)227   uint64 GetKey(const HloInstruction* instruction) const {
228     uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id());
229     uint64 module_id =
230         absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id());
231     return (module_id << 32) | unique_id;
232   }
233   // Return the index of the given instruction.
GetIndexInternal(const HloInstruction * instruction)234   int GetIndexInternal(const HloInstruction* instruction) const {
235     return FindOrDie(indices_, GetKey(instruction));
236   }
237 
238   // The number of instructions in the reachability map.
239   const size_t size_;
240 
241   // Dense assignment from HloInstruction::unique_id to number. These numbers
242   // index into the bit_vectors_ vector and into the bits within a BitVector.
243   absl::flat_hash_map<uint64, int> indices_;
244 
245   // Bitvectors holding the reachability to each instruction. The bit vector for
246   // instruction X includes ones for each instruction which X is reachable from.
247   std::vector<BitVector> bit_vectors_;
248 
249   // A temporary used by SetReachabilityToUnion to avoid an allocation with each
250   // call to the method.
251   BitVector tmp_bit_vector_;
252 };
253 
254 }  // namespace xla
255 
256 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
257