1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 18 19 #include <list> 20 #include <memory> 21 #include <unordered_map> 22 #include <unordered_set> 23 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/step_stats.pb.h" 26 #include "tensorflow/core/grappler/costs/cost_estimator.h" 27 #include "tensorflow/core/grappler/costs/graph_properties.h" 28 #include "tensorflow/core/grappler/costs/op_context.h" 29 #include "tensorflow/core/grappler/costs/virtual_placer.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 35 struct NodeState { 36 // A node (i.e., an op) takes a set of input:port pairs and produces 37 // a set of output ports. 38 39 // Cross references to input and output nodes from graphdef. 40 std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs. 41 // List of output nodes (a list of nodes that takes this output port as input) 42 // keyed by port_num. Note that port_num -1 is used for control dependency. 43 std::unordered_map<int, std::vector<const NodeDef*>> outputs; 44 45 // Info from GraphProperties. 46 std::vector<OpInfo::TensorProperties> input_properties; 47 std::vector<OpInfo::TensorProperties> output_properties; 48 49 // Canonical device name used within VirtualScheduler. 50 string device_name; 51 52 // States updated as scheduling nodes. 53 int num_inputs_ready; 54 std::unordered_map<int, int> num_outputs_executed; 55 Costs::Duration time_ready; 56 Costs::Duration time_scheduled; 57 Costs::Duration time_finished; 58 // Time that all the consumers are executed (hence, no need to keep this 59 // output in memory), keyed by port_num. 60 std::unordered_map<int, Costs::Duration> time_no_references; 61 62 // Note that a node may have multiple output ports. The length of outputs, 63 // num_outputs_executed, and time_no_references should be 64 // identical when a NodeState is fully initialized. 65 // They should be 1 + output_properties.size() as we add [-1] for control 66 // dependency. 67 68 // Node will be ready to be executed at time_ready, scheduled at 69 // time_scheduled, and finishes execution at time_finished. 70 // Each output port uses up memory space from time_scheduled to its 71 // time_no_references. 72 73 Costs node_costs; // Node costs per execution TotalNodeCostsNodeState74 Costs TotalNodeCosts() const { 75 return MultiplyCosts(node_costs, execution_count); 76 } 77 // How many times this node has been executed, e.g. in a while loop. 78 int execution_count; 79 80 // Output shape incompatible between shape annotation and shape inference. 81 bool shape_incompatible; 82 NodeStateNodeState83 NodeState() { 84 num_inputs_ready = 0; 85 time_ready = Costs::Duration::max(); 86 time_scheduled = Costs::Duration::max(); 87 time_finished = Costs::Duration::max(); 88 execution_count = 0; 89 shape_incompatible = false; 90 // Note that num_outputs_executed and time_no_references are not initialized 91 // here, since we don't know the size (i.e., # outputs for this node). 92 } 93 }; 94 95 struct DeviceState { 96 // Nodes executed on this device in execution order. 97 std::vector<const NodeDef*> nodes_executed; 98 99 struct NodePairHash { 100 public: operatorDeviceState::NodePairHash101 const std::size_t operator()( 102 const std::pair<const NodeDef*, int>& element) const { 103 return std::hash<const NodeDef*>()(element.first); 104 } 105 }; 106 107 // Nodes currently allocated in memory: set of NodeDef* and port_num pairs 108 // so that we can track which output of the node is in memory. 109 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 110 nodes_in_memory; 111 112 // Nodes allocated in memory persistently: e.g., Variables. 113 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 114 persistent_nodes; 115 116 // Snapshot of nodes_in_memory, when memory usage is at peak. 117 // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. 118 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 119 mem_usage_snapshot_at_peak; 120 121 Costs device_costs; 122 std::map<string, Costs> op_to_cost; // Per-op cost. 123 124 int64 memory_usage; // Current temporary memory usage 125 int64 max_memory_usage; // Max temporary memory usage 126 127 // Shape annotation statistics. 128 struct ShapeAnnotationStats { 129 // Number of ops with shape annotated. 130 int64 num_ops_annotated = 0; 131 // Number of ops executed multiple times (e.g. in a loop). 132 int64 num_ops_executed_more_than_once = 0; 133 // Number of ops executed: account for execution count. 134 int64 num_ops_executed = 0; 135 // Number of ops with dynamic shapes (e.g. shape changes in a loop). 136 int64 num_ops_with_dynamic_shapes = 0; 137 // Number of ops with incompatible shapes between annotation and shape 138 // inference. 139 int64 num_ops_with_incompatible_shapes = 0; 140 } shape_annotation_stats; 141 DeviceStateDeviceState142 DeviceState() { 143 device_costs = Costs::ZeroCosts(); 144 device_costs.num_ops_total = 0; 145 memory_usage = 0; 146 max_memory_usage = 0; 147 } 148 GetCurrTimeDeviceState149 Costs::Duration GetCurrTime() const { return device_costs.execution_time; } 150 }; 151 152 // ReadyNodeManager (abstract class): 153 // Keeps ready nodes and picks the best one to be scheduled. 154 class ReadyNodeManager { 155 public: ReadyNodeManager()156 ReadyNodeManager() {} ~ReadyNodeManager()157 virtual ~ReadyNodeManager() {} Init(const std::unordered_map<const NodeDef *,NodeState> * node_map)158 virtual Status Init( 159 const std::unordered_map<const NodeDef*, NodeState>* node_map) { 160 return Status::OK(); 161 } 162 virtual void AddNode(const NodeDef* node) = 0; 163 virtual const NodeDef* GetCurrNode() = 0; 164 virtual void RemoveCurrNode() = 0; 165 virtual bool Empty() const = 0; 166 }; 167 168 class FIFOManager : public ReadyNodeManager { 169 public: FIFOManager()170 FIFOManager() : ReadyNodeManager() {} ~FIFOManager()171 ~FIFOManager() override {} AddNode(const NodeDef * node)172 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } GetCurrNode()173 const NodeDef* GetCurrNode() override { 174 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; 175 return nodes_.front(); 176 } RemoveCurrNode()177 void RemoveCurrNode() override { nodes_.pop_front(); } Empty()178 bool Empty() const override { return nodes_.empty(); } 179 180 private: 181 std::list<const NodeDef*> nodes_; 182 }; 183 184 // The LIFOManager schedules nodes by returning the last one added to the 185 // scheduler. A node is executed and then its ready outputs are newly added to 186 // the scheduler, so the LIFOManager will return outputs to a node following 187 // that node's execution. 188 class LIFOManager : public ReadyNodeManager { 189 public: LIFOManager()190 LIFOManager() : ReadyNodeManager() {} ~LIFOManager()191 ~LIFOManager() override {} AddNode(const NodeDef * node)192 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } 193 const NodeDef* GetCurrNode() override; 194 void RemoveCurrNode() override; Empty()195 bool Empty() const override { return nodes_.empty(); } 196 197 private: 198 std::list<const NodeDef*> nodes_; 199 // Keep track of the current node being executed by saving its position. 200 // Necessary because nodes may be added to the end of the list while a node is 201 // executing, and we want to remove the correct node (the one that is 202 // executing) rather than the new ones being added. 203 std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end(); 204 }; 205 206 // Abstract class that maintains a heap/priority queue for scheduling ready 207 // nodes. Derived class needs to implement the Greater() function which returns 208 // the comparator for the heap. 209 class HeapReadyManager : public ReadyNodeManager { 210 public: 211 HeapReadyManager(); 212 Status Init( 213 const std::unordered_map<const NodeDef*, NodeState>* node_map) override; ~HeapReadyManager()214 ~HeapReadyManager() override {} AddNode(const NodeDef * node)215 void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } 216 const NodeDef* GetCurrNode() override; 217 void RemoveCurrNode() override; 218 bool Empty() const override; 219 220 protected: 221 virtual std::function<bool(const NodeDef*, const NodeDef*)> Greater() = 0; 222 // Move all the nodes in the waiting_queue_ to nodes_. 223 void DrainWaitingQueue(); 224 225 // nodes_ is the main queue, where we construct heap, and the front is the 226 // current node. 227 std::vector<const NodeDef*> nodes_; 228 // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(), 229 // which returns the front of the nodes_, always returns the same node, 230 // even if any of new nodes has time_ready smaller than the current node's. 231 std::vector<const NodeDef*> waiting_queue_; 232 // Comparator functor for heap; stl heap is max heap, so we use "greater than" 233 // functor for keeping the smallest time_ready node at the front of heap. 234 std::function<bool(const NodeDef*, const NodeDef*)> greater_; 235 236 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 237 // Not owned by FirstReadyManager. 238 const std::unordered_map<const NodeDef*, NodeState>* node_map_; 239 }; 240 241 // FirstReadyManager picks a node with the minimum time_ready value. 242 // Behavior is deterministic when there are more than one nodes with the minimum 243 // time_ready value with unique node names as the tie-breaker. 244 class FirstReadyManager : public HeapReadyManager { 245 public: FirstReadyManager()246 FirstReadyManager() : HeapReadyManager() {} ~FirstReadyManager()247 ~FirstReadyManager() override {} 248 249 protected: 250 std::function<bool(const NodeDef*, const NodeDef*)> Greater() override; 251 }; 252 253 // PriorityReadyManager uses the given node priorities when picking up next node 254 // from all the ready nodes. 255 class PriorityReadyManager : public HeapReadyManager { 256 public: PriorityReadyManager()257 PriorityReadyManager() : HeapReadyManager() {} ~PriorityReadyManager()258 ~PriorityReadyManager() override {} 259 void AddNode(const NodeDef* node) override; 260 261 // Note this should be called after Init(). 262 Status SetPriority(const std::unordered_map<string, int>& node_priority); 263 264 protected: 265 std::function<bool(const NodeDef*, const NodeDef*)> Greater() override; 266 267 private: 268 // A map from unique node name to priority. Lower number means higher 269 // priority. 270 std::unordered_map<string, int> node_priority_; 271 }; 272 273 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal 274 // ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv 275 // ops, and then it chooses FirstReady among the ops chosen from each 276 // internal NodeManagers. The objective is to maximize producer-consumer 277 // locality within device, while processing nodes across devices, including 278 // _Send and _Recv, fairly, in terms of their time_ready. 279 class CompositeNodeManager : public ReadyNodeManager { 280 public: 281 CompositeNodeManager(); ~CompositeNodeManager()282 ~CompositeNodeManager() override {} 283 284 Status Init( 285 const std::unordered_map<const NodeDef*, NodeState>* node_map) override; 286 void AddNode(const NodeDef* node) override; 287 const NodeDef* GetCurrNode() override; 288 void RemoveCurrNode() override; 289 bool Empty() const override; 290 291 private: 292 // Internal ready node managers: 293 // LIFO for normal ops to maximize producer consumer locality. 294 // One LIFO per device. 295 std::unordered_map<string, LIFOManager> ops_lifo_map_; 296 // FirstReady for send and recv. Handle send and recv separately ensures that 297 // send and recv do not block previously read ops with LIFO schedule. 298 FirstReadyManager send_manager_; 299 FirstReadyManager recv_manager_; 300 301 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 302 // Not owned by CompositeReadyManager. 303 const std::unordered_map<const NodeDef*, NodeState>* node_map_; 304 305 // Cached curr node. Set back to nullptr from RemoveCurrNode(). 306 const NodeDef* curr_node_; 307 }; 308 309 // Constructs a ready node manager from the given string. 310 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory( 311 const string& ready_node_manager); 312 313 // The virtual scheduler emulates execution of nodes in a graph, considering 314 // dependencies, device, etc. 315 class VirtualScheduler { 316 public: 317 // Does not take ownership of cluster or ready_nodes. 318 VirtualScheduler(const bool use_static_shapes, 319 const bool use_aggressive_shape_inference, Cluster* cluster, 320 ReadyNodeManager* ready_nodes, 321 std::unique_ptr<VirtualPlacer> placer); 322 323 // Initializes the scheduler for the specific grappler item. 324 // Should be called immediately after the c'tor or when the scheduler will be 325 // reused for a new grappler item. All internal states of the scheduler 326 // related to the previous grappler item will be reset/cleared. 327 // 328 // This function should be called at least once after the scheduler is 329 // constructed. An uninitialized or failed-to-initialize scheduler will cause 330 // undefined behavior. 331 Status Init(const GrapplerItem* item); 332 333 OpContext GetCurrNode() const; 334 335 // Returns true if there is any node to be scheduled. 336 bool MarkCurrNodeExecuted(const Costs& node_costs); 337 338 // Prints out summary of execution (timing, memory usage, etc.) 339 Costs Summary() const; 340 // Like the above, but writes detailed stats to RunMetadata. 341 // If metadata is nullptr, then just calls and return Summary(). 342 Costs Summary(RunMetadata* metadata); 343 // Generates RunMetadata's step_stats and partition_graphs fields from results 344 // of the virtual execution of the graph. 345 void GenerateRunMetadata(RunMetadata* metadata); 346 347 // Returns per device memory usage. 348 const std::unordered_map<string, int64> GetPeakMemoryUsage() const; 349 const std::unordered_map<string, int64> GetPersistentMemoryUsage() const; 350 351 // Returns VirtualScheduler (read only) device and node states. GetDeviceStates()352 const std::unordered_map<string, DeviceState>* GetDeviceStates() const { 353 return &device_; 354 } GetNodeStates()355 const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { 356 return &node_map_; 357 } 358 enable_mem_usage_tracking()359 void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; } 360 361 private: 362 // Methods called from Init(). Fails if initialize_ is set. 363 void MaybeUpdateInputOutput(const NodeDef* node); 364 NodeState& GetNodeStateOrCreateIt(const NodeDef* node); 365 std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( 366 const NodeDef* from, const NodeDef* to, const NodeDef* input_node, 367 const string& input_name); 368 string DeviceName(const NodeDef* node) const; 369 string SanitizedDeviceName(const NodeDef* node) const; 370 string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; 371 372 // Helper methods. 373 void AddOutputNodesToReadyQueue(const NodeDef* node, 374 const Costs::Duration& curr_time); 375 376 // Scheduler states: 377 ReadyNodeManager* ready_nodes_; // Not owned. 378 std::unordered_map<const NodeDef*, NodeState> node_map_; 379 std::unordered_map<string, DeviceState> device_; 380 381 // Pool of NodeDefs for SendRecv and Identity ops created. 382 std::vector<std::unique_ptr<NodeDef>> additional_nodes_; 383 384 // Stats: 385 // Op counts with key with input shape. 386 // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]" 387 std::map<string, int> op_counts_; 388 // Individual op costs with key with input shape. 389 // Integer field for execution time in micro seconds. 390 // Boolean field for whether the cost is accurate. 391 std::map<string, std::pair<int, bool>> op_costs_; 392 393 Costs graph_costs_; // Graph cost. 394 std::map<string, Costs> op_to_cost_; // Per-op cost. 395 396 // Auxiliary data structures for constructing NodeState and DeviceState. 397 std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init(). 398 Cluster* cluster_; // Not owned. 399 400 const GrapplerItem* grappler_item_; // Not owned. 401 bool use_static_shapes_; 402 bool initialized_; 403 bool track_mem_usage_snapshot_; 404 const bool use_aggressive_shape_inference_; 405 406 std::unique_ptr<VirtualPlacer> placer_; 407 }; 408 409 } // namespace grappler 410 } // end namespace tensorflow 411 412 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 413