/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Call graph for an HLO module.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_

#include <ostream>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// The context in which a computation is called by another computation.
enum class CallContext {
  // In a parallel context the computation is applied to each element of the
  // array argument(s). kMap and kReduce instructions call computations in
  // parallel context.
  kParallel,

  // In a sequential context the computation is applied to the entire argument
  // shape(s). kCall and kWhile (body and condition) call computations in
  // sequential context.
  kSequential,

  // A computation is called from both a parallel and sequential context.
  kBoth,

  // During call graph construction kNone is used to indicate that the context
  // has not been determined. This is the top value for the context
  // lattice. After construction, no call sites or call graph nodes should have
  // this value.
  kNone
};

// Returns a human-readable name for the given call context (for debugging and
// ToString output).
string CallContextToString(CallContext context);
std::ostream& operator<<(std::ostream& out, const CallContext& context);

// Returns the context (parallel or sequential) in which an instruction with
// the given opcode calls its computations.
CallContext GetInstructionCallContext(HloOpcode opcode);

// Represents an HLO instruction which calls one or more computations.
class CallSite {
 public:
  // `instruction` must be non-null; CHECK_NOTNULL fails otherwise. The called
  // computations and context are stored immutably for the life of the object.
  CallSite(HloInstruction* instruction,
           const std::vector<HloComputation*>& called_computations,
           CallContext context)
      : instruction_(CHECK_NOTNULL(instruction)),
        called_computations_(called_computations),
        context_(context) {}

  // Returns the instruction associated with this call site.
  HloInstruction* instruction() const { return instruction_; }

  // Returns the computations called at this call site.
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Returns the context in which computations are called at this call site.
  CallContext context() const { return context_; }

  string ToString() const;

 private:
  // The calling instruction.
  HloInstruction* instruction_;

  // The computations called by this callsite.
  const std::vector<HloComputation*> called_computations_;

  // The context in which the computations are called.
  const CallContext context_;
};

// A node in the call graph representing an HLO computation.
class CallGraphNode {
 public:
  // Constructs a node for `computation`. Edges (callees/callers/callsites)
  // are populated later by CallGraph during graph construction.
  CallGraphNode(HloComputation* computation);

  // Returns the computation represented by this call graph node.
  HloComputation* computation() const { return computation_; }

  // Returns the call sites in this computation. These are the instructions in
  // this computation which call other computations.
  const std::vector<CallSite>& callsites() const { return callsites_; }

  // Returns the callsite associated with the given instruction. If this
  // instruction calls no computations nullptr is returned.
  // Prerequisite: instruction is in the computation associated with this call
  // graph node.
  const CallSite* GetCallSite(const HloInstruction* instruction) const;

  // Returns the computations called by this computation.
  const std::vector<HloComputation*>& callees() const { return callees_; }

  // Returns the call sites in other computations which call this computation.
  const std::vector<CallSite>& caller_callsites() const {
    return caller_callsites_;
  }

  // Returns the computations which call this computation.
  const std::vector<HloComputation*>& callers() const { return callers_; }

  // Returns the context in which this computation is called.
  CallContext context() const { return context_; }

  // Returns the depth of this node in the call graph. The depth is defined as
  // the length of the longest call chain from a computation with no callers
  // (usually the entry computation node) to this node.
  int depth() const { return depth_; }

  string ToString() const;

 private:
  // Only CallGraph can modify CallGraphNode.
  friend class CallGraph;

  // Sets the context in which this computation is called.
  void set_context(CallContext value) { context_ = value; }

  // Sets the depth of this node in the graph.
  void set_depth(int value) { depth_ = value; }

  // Adds a callsite which calls this computation. Updates callers to include
  // the calling computation.
  void AddCallerCallSite(const CallSite& caller_callsite);

  // If instruction calls any computations adds a call site for this
  // instruction to the call graph node. If the instruction calls no
  // computations then no call site is added.
  void AddCallSiteForInstruction(HloInstruction* instruction);

  // Computation represented by this call graph node.
  HloComputation* computation_;

  // The computations called by this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callees_;
  absl::flat_hash_set<HloComputation*> callee_set_;

  // The computations which call this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callers_;
  absl::flat_hash_set<HloComputation*> caller_set_;

  // The call sites in this computation.
  std::vector<CallSite> callsites_;

  // The map from instruction to index in callsites_ for looking up the
  // callsite (if any) associated with a particular instruction in this
  // computation.
  absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;

  // The call sites in other computations which call this computation.
  std::vector<CallSite> caller_callsites_;

  // The context in which this computation is called.
  CallContext context_ = CallContext::kNone;

  // The depth of this node in the call graph.
  int depth_ = 0;
};

// The call graph for an HLO module. The graph includes a node for each
// computation in the module.
class CallGraph {
 public:
  using VisitorFunction = std::function<Status(const CallGraphNode&)>;

  // Builds and returns a call graph for the given HLO module.
  static std::unique_ptr<CallGraph> Build(const HloModule* module);

  // Returns the node associated with the given computation.
  const CallGraphNode& GetNode(const HloComputation* computation) const;
  CallGraphNode& GetNode(const HloComputation* computation);

  // Returns the vector of all nodes in the call graph.
  const std::vector<CallGraphNode>& nodes() const { return nodes_; }

  // Calls the given function on each node in the call graph. Nodes are visited
  // in post order (callees before callers). If visit_unreachable_nodes is true
  // then all nodes in the call graph are visited. Otherwise only those nodes
  // reachable from the entry computation are visited.
  Status VisitNodes(const VisitorFunction& visitor_func,
                    bool visit_unreachable_nodes = true) const;

  // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
  // dominates computation 'b' iff all callgraph paths in the caller-to-callee
  // direction from a root computation to 'b' pass through computation
  // 'a'. Trivially, a computation dominates itself.
  bool Dominates(const HloComputation* a, const HloComputation* b) const;

  // Returns whether 'instruction' is contained in 'computation' either
  // directly ('instruction->parent' is 'computation') or indirectly
  // ('computation' dominates 'instruction->parent' in the call graph).
  bool InstructionIsNestedIn(const HloInstruction* instruction,
                             const HloComputation* computation) const {
    return Dominates(computation, instruction->parent());
  }

  // Returns the nearest call graph ancestors of instructions 'a' and 'b' for
  // which the ancestors are in the same computation. An instruction is an call
  // graph ancestor of 'a' if the instruction calls the computation containing
  // 'a' either directly or transitively. Degeneratively an instruction is an
  // ancestor of itself. nullptr is returned if there is no common ancestor or
  // if the caller chain of 'a' or 'b' diverges (has multiple callers) before
  // the nearest common ancestor.
  //
  // Example:
  //
  // Entry computation:
  //   %x = Call(A, {Constant(42.0)})
  //   %y = Call(B, {%x})
  //
  // Computation A:
  //   %a = Negate(Param())
  //
  // Computation B:
  //   %b = Exp(Param());
  //
  // If called with %a and %b, this function would return (%x, %y). %x is an
  // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the
  // same computation.
  std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation(
      HloInstruction* a, HloInstruction* b) const;

  // Returns whether the call graph is flattened. A call graph is flattened if
  // every computation called in a sequential context (eg, kWhile or kCall) has
  // zero or one callsite, and no computation is called from both a parallel
  // and sequential context. The call graph of a module can be flattened with
  // FlattenCallGraph.
  bool IsFlattened() const;

  // Returns a vector of instructions calling the passed computation.
  // (Often a vector of size 1.)
  std::vector<HloInstruction*> GetComputationCallers(HloComputation* c);

  string ToString() const;

 private:
  CallGraph(const HloModule* module);

  // Not copyable.
  CallGraph(const CallGraph&) = delete;
  CallGraph& operator=(const CallGraph&) = delete;

  // Sets the call contexts for every node in the graph.
  void SetCallContexts();

  // Sets the call node depths for every node in the graph.
  void SetNodeDepths();

  // Helper method for VisitNodes(). Traverses the call graph from 'node' in
  // DFS post order (callee before caller) calling visitor_func on each node.
  // Adds nodes to 'visited' as each node is visited. Skips nodes already in
  // 'visited'.
  Status VisitNodesInternal(
      const VisitorFunction& visitor_func, const CallGraphNode& node,
      absl::flat_hash_set<const CallGraphNode*>* visited) const;

  // Recursive helper for computing whether 'a' dominates 'b' in the call
  // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
  // and 'visited' is the set of computations which have been visited.
  bool DominatesHelper(
      const HloComputation* a, const HloComputation* b,
      absl::flat_hash_set<const HloComputation*>* visited) const;

  // The HLO module represented by this call graph.
  const HloModule* module_ = nullptr;

  // Vector of all nodes in the call graph.
  std::vector<CallGraphNode> nodes_;

  // Map from HLO computation to the index of the corresponding call graph
  // node in nodes_.
  absl::flat_hash_map<const HloComputation*, int64> node_indices_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_