1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Call graph for an HLO module. 17 18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ 20 21 #include <ostream> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/container/flat_hash_set.h" 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_module.h" 28 29 namespace xla { 30 31 // The context in which a computation is called by another computation. 32 enum class CallContext { 33 // In an embedded call context, the body of the function cannot allocate 34 // buffers. 35 kEmbedded, 36 37 // A control flow call context can allocate buffers. 38 kControlFlow, 39 40 // A computation is called from both an embedded and control flow context. 41 kBoth, 42 43 // During call graph construction kNone is used to indicate that the context 44 // has not been determined. This is the top value for the context 45 // lattice. After construction, no call sites or call graph nodes should have 46 // this value. 47 kNone 48 }; 49 50 std::string CallContextToString(CallContext context); 51 std::ostream& operator<<(std::ostream& out, const CallContext& context); 52 53 CallContext GetInstructionCallContext(HloOpcode opcode); 54 55 // Represents an HLO instruction which calls one or more computations. 56 class CallSite { 57 public: CallSite(HloInstruction * instruction,absl::Span<HloComputation * const> called_computations,CallContext context)58 CallSite(HloInstruction* instruction, 59 absl::Span<HloComputation* const> called_computations, 60 CallContext context) 61 : instruction_(CHECK_NOTNULL(instruction)), 62 called_computations_(called_computations.begin(), 63 called_computations.end()), 64 context_(context) {} 65 66 // Returns the instruction associated with this call site. instruction()67 HloInstruction* instruction() const { return instruction_; } 68 69 // Returns the computations called at this call site. called_computations()70 absl::Span<HloComputation* const> called_computations() const { 71 return called_computations_; 72 } 73 74 // Returns the context in which computations are called at this call site. context()75 CallContext context() const { return context_; } 76 77 std::string ToString() const; 78 79 private: 80 // The calling instruction. 81 HloInstruction* instruction_; 82 83 // The computations called by this callsite. 84 const absl::InlinedVector<HloComputation*, 2> called_computations_; 85 86 // The context in which the computations are called. 87 const CallContext context_; 88 }; 89 90 // A node in the call graph representing an HLO computation. 91 class CallGraphNode { 92 public: 93 CallGraphNode(HloComputation* computation); 94 95 // Returns the computation represented by this call graph node. computation()96 HloComputation* computation() const { return computation_; } 97 98 // Returns the call sites in this computation. These are the instructions in 99 // this computation which call other computations. callsites()100 absl::Span<const CallSite> callsites() const { return callsites_; } 101 102 // Returns the callsite associated with the given instruction. If this 103 // instruction calls no computations nullptr is returned. 104 // Prerequisite: instruction is in the computation associated with this call 105 // graph node. 106 const CallSite* GetCallSite(const HloInstruction* instruction) const; 107 108 // Returns the computations called by this computation. callees()109 absl::Span<HloComputation* const> callees() const { return callees_; } 110 111 // Returns the call sites in other computations which call this computation. caller_callsites()112 absl::Span<const CallSite> caller_callsites() const { 113 return caller_callsites_; 114 } 115 116 // Returns the computations which call this computation. callers()117 absl::Span<HloComputation* const> callers() const { return callers_; } 118 119 // Returns the context in which this computation is called. context()120 CallContext context() const { return context_; } 121 122 // Returns the depth of this node in the call graph. The depth is defined as 123 // the length of the longest call chain from a computation with no callers 124 // (usually the entry computation node) to this node. depth()125 int depth() const { return depth_; } 126 127 std::string ToString() const; 128 129 CallGraphNode(const CallGraphNode&) = delete; 130 CallGraphNode& operator=(const CallGraphNode&) = delete; 131 CallGraphNode(CallGraphNode&&) = default; 132 CallGraphNode& operator=(CallGraphNode&&) = default; 133 134 private: 135 // Only CallGraph can modify CallGraphNode. 136 friend class CallGraph; 137 138 // Sets the context in which this computation is called. set_context(CallContext value)139 void set_context(CallContext value) { context_ = value; } 140 141 // Sets the depth of this node in the graph. set_depth(int value)142 void set_depth(int value) { depth_ = value; } 143 144 // Adds a callsite which calls this computation. Updates callers to include 145 // the calling computation. 146 void AddCallerCallSite(const CallSite& caller_callsite); 147 148 // If instruction calls any computations adds a call site for this instruction 149 // to the call graph node. If the instruction calls no computations then no 150 // call site is added. 151 void AddCallSiteForInstruction(HloInstruction* instruction); 152 153 // Computation represented by this call graph node. 154 HloComputation* computation_; 155 156 // The computations called by this computation. The vector is used for a 157 // stable ordering and the set enables fast membership testing. 158 absl::InlinedVector<HloComputation*, 1> callees_; 159 absl::flat_hash_set<HloComputation*> callee_set_; 160 161 // The computations which call this computation. The vector is used for a 162 // stable ordering and the set enables fast membership testing. 163 absl::InlinedVector<HloComputation*, 1> callers_; 164 absl::flat_hash_set<HloComputation*> caller_set_; 165 166 // The call sites in this computation 167 absl::InlinedVector<CallSite, 1> callsites_; 168 169 // The map from instruction to index in callsites_ for looking up the callsite 170 // (if any) associated with a particular instruction in this computation. 171 absl::flat_hash_map<const HloInstruction*, int64_t> callsite_instructions_; 172 173 // The call sites in other computations which call this computation. 174 absl::InlinedVector<CallSite, 1> caller_callsites_; 175 176 // The context in which this computation is called. 177 CallContext context_ = CallContext::kNone; 178 179 // The depth of this node in the call graph. 180 int depth_ = 0; 181 }; 182 183 // The call graph for an HLO module. The graph includes a node for each 184 // computation in the module. 185 class CallGraph { 186 public: 187 using VisitorFunction = std::function<Status(const CallGraphNode&)>; 188 189 // Builds and returns a call graph for the given HLO module. 190 static std::unique_ptr<CallGraph> Build(const HloModule* module); 191 192 // Returns the node associated with the given computation. 193 const CallGraphNode& GetNode(const HloComputation* computation) const; 194 CallGraphNode& GetNode(const HloComputation* computation); 195 196 // Returns the vector of all nodes in the call graph. nodes()197 const std::vector<CallGraphNode>& nodes() const { return nodes_; } 198 199 // Calls the given function on each node in the call graph. Nodes are visited 200 // in post order (callees before callers). If visit_unreachable_nodes is true 201 // then all nodes in the call graph are visited. Otherwise only those nodes 202 // reachable from the entry computation are visited. 203 Status VisitNodes(const VisitorFunction& visitor_func, 204 bool visit_unreachable_nodes = true) const; 205 206 // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' 207 // dominates computation 'b' iff all callgraph paths in the caller-to-callee 208 // direction from a root computation to 'b' pass through computation 209 // 'a'. Trivially, a computation dominates itself. 210 bool Dominates(const HloComputation* a, const HloComputation* b) const; 211 212 // Returns whether 'instruction' is contained in 'computation' either directly 213 // ('instruction->parent' is 'computation') or indirectly ('computation' 214 // dominates 'instruction->parent' in the call graph). InstructionIsNestedIn(const HloInstruction * instruction,const HloComputation * computation)215 bool InstructionIsNestedIn(const HloInstruction* instruction, 216 const HloComputation* computation) const { 217 return Dominates(computation, instruction->parent()); 218 } 219 220 // Returns the nearest call graph ancestors of instructions 'a' and 'b' for 221 // which the ancestors are in the same computation. An instruction is an call 222 // graph ancestor of 'a' if the instruction calls the computation containing 223 // 'a' either directly or transitively. Degeneratively an instruction is an 224 // ancestor of itself. nullptr is returned if there is no common ancestor or 225 // if the caller chain of 'a' or 'b' diverges (has multiple callers) before 226 // the nearest common ancestor. 227 // 228 // Example: 229 // 230 // Entry computation: 231 // %x = Call(A, {Constant(42.0)}) 232 // %y = Call(B, {%x}) 233 // 234 // Computation A: 235 // %a = Negate(Param()) 236 // 237 // Computation B: 238 // %b = Exp(Param()); 239 // 240 // If called with %a and %b, this function would return (%x, %y). %x is an 241 // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same 242 // computation. 243 std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation( 244 HloInstruction* a, HloInstruction* b) const; 245 246 // Returns whether the call graph is flattened. A call graph is flattened if 247 // every computation called in a sequential context (eg, kWhile or kCall) has 248 // zero or one callsite, and no computation is called from both a parallel and 249 // sequential context. The call graph of a module can be flattened with 250 // FlattenCallGraph. 251 bool IsFlattened() const; 252 253 // Returns a vector of instructions calling the passed computation. 254 // (Often a vector of size 1.) 255 std::vector<HloInstruction*> GetComputationCallers(HloComputation* c) const; 256 257 std::string ToString() const; 258 259 private: 260 CallGraph(const HloModule* module); 261 262 // Not copyable. 263 CallGraph(const CallGraph&) = delete; 264 CallGraph& operator=(const CallGraph&) = delete; 265 266 // Sets the call contexts for every node in the graph. 267 void SetCallContexts(); 268 269 // Sets the call node depths for every node in the graph. 270 void SetNodeDepths(); 271 272 // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS 273 // post order (callee before caller) calling visitor_func on each node. Adds 274 // nodes to 'visited' as each node is visited. Skips nodes already in 275 // 'visited'. 276 Status VisitNodesInternal( 277 const VisitorFunction& visitor_func, const CallGraphNode& node, 278 absl::flat_hash_set<const CallGraphNode*>* visited) const; 279 280 // Recursive helper for computing whether 'a' dominates 'b' in the call 281 // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), 282 // and 'visited' is the set of computations which have been visited. 283 bool DominatesHelper( 284 const HloComputation* a, const HloComputation* b, 285 absl::flat_hash_set<const HloComputation*>* visited) const; 286 287 // The HLO module represented by this call graph. 288 const HloModule* module_ = nullptr; 289 290 // Vector of all nodes in the call graph. 291 std::vector<CallGraphNode> nodes_; 292 293 // Map from HLO computation to the index of the corresponding call graph node 294 // in nodes_. 295 absl::flat_hash_map<const HloComputation*, int64_t> node_indices_; 296 }; 297 298 } // namespace xla 299 300 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ 301