/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_

#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "absl/base/attributes.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

ABSL_CONST_INIT extern const char kAttrInputSrc[];
ABSL_CONST_INIT extern const char kAttrSrcDevice[];
ABSL_CONST_INIT extern const char kAttrDstDevice[];
ABSL_CONST_INIT extern const char kAttrTensorName[];
ABSL_CONST_INIT extern const char kChannelDevice[];
ABSL_CONST_INIT extern const char kStreaming[];

struct NodeState {
  // A node (i.e., an op) takes a set of input:port pairs and produces
  // a set of output ports.

  // Cross-references to input and output nodes from the graphdef.
  std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
  // List of output nodes (the nodes that take this output port as input),
  // keyed by port_num. Note that port_num -1 is used for control dependency.
  std::unordered_map<int, std::vector<const NodeDef*>> outputs;

  // Info from GraphProperties.
  std::vector<OpInfo::TensorProperties> input_properties;
  std::vector<OpInfo::TensorProperties> output_properties;

  // Canonical device name used within VirtualScheduler.
  string device_name;

  // State updated while scheduling nodes.
  int num_inputs_ready;
  std::unordered_map<int, int> num_outputs_executed;
  Costs::Duration time_ready;
  Costs::Duration time_scheduled;
  Costs::Duration time_finished;
  // Time when all the consumers have executed (hence, no need to keep this
  // output in memory), keyed by port_num.
  std::unordered_map<int, Costs::Duration> time_no_references;

  // Note that a node may have multiple output ports. The lengths of outputs,
  // num_outputs_executed, and time_no_references should be identical when a
  // NodeState is fully initialized. They should be
  // 1 + output_properties.size(), as we add [-1] for control dependency.

  // A node will be ready to execute at time_ready, scheduled at
  // time_scheduled, and finishes execution at time_finished.
  // Each output port uses up memory space from time_scheduled to its
  // time_no_references.
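  //
  // For example (an illustrative sketch, not taken from the implementation):
  // a node with two output ports has output_properties.size() == 2, so
  // outputs, num_outputs_executed, and time_no_references each hold three
  // entries, keyed by port_num in {-1, 0, 1}, where -1 is the control
  // dependency port.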
  Costs node_costs;  // Node costs per execution.
  Costs TotalNodeCosts() const {
    return MultiplyCosts(node_costs, execution_count);
  }
  // How many times this node has been executed, e.g. in a while loop.
  int execution_count;

  // Output shape incompatible between shape annotation and shape inference.
  bool shape_incompatible;

  NodeState() {
    num_inputs_ready = 0;
    time_ready = Costs::Duration::max();
    time_scheduled = Costs::Duration::max();
    time_finished = Costs::Duration::max();
    execution_count = 0;
    shape_incompatible = false;
    // Note that num_outputs_executed and time_no_references are not
    // initialized here, since we don't know the size (i.e., # outputs for
    // this node).
  }
};

struct DeviceState {
  // Nodes executed on this device in execution order.
  std::vector<const NodeDef*> nodes_executed;

  struct NodePairHash {
   public:
    std::size_t operator()(
        const std::pair<const NodeDef*, int>& element) const {
      return std::hash<const NodeDef*>()(element.first);
    }
  };

  // Nodes currently allocated in memory: set of NodeDef* and port_num pairs,
  // so that we can track which output of the node is in memory.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      nodes_in_memory;

  // Nodes allocated in memory persistently: e.g., Variables.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      persistent_nodes;

  // Snapshot of nodes_in_memory, taken when memory usage is at its peak.
  // Like nodes_in_memory, it's a set of NodeDef* and port_num pairs.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      mem_usage_snapshot_at_peak;

  Costs device_costs;
  std::map<string, Costs> op_to_cost;  // Per-op cost.

  int64 memory_usage;      // Current temporary memory usage.
  int64 max_memory_usage;  // Max temporary memory usage.

  // Shape annotation statistics.
  struct ShapeAnnotationStats {
    // Number of ops with shape annotated.
    int64 num_ops_annotated = 0;
    // Number of ops executed multiple times (e.g. in a loop).
    int64 num_ops_executed_more_than_once = 0;
    // Number of ops executed: accounts for execution count.
    int64 num_ops_executed = 0;
    // Number of ops with dynamic shapes (e.g. shape changes in a loop).
    int64 num_ops_with_dynamic_shapes = 0;
    // Number of ops with incompatible shapes between annotation and shape
    // inference.
    int64 num_ops_with_incompatible_shapes = 0;
  } shape_annotation_stats;

  DeviceState() {
    device_costs = Costs::ZeroCosts();
    device_costs.num_ops_total = 0;
    memory_usage = 0;
    max_memory_usage = 0;
  }

  Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};

// ReadyNodeManager (abstract class):
// Keeps ready nodes and picks the best one to be scheduled.
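//
// A typical driver loop looks like the following sketch (illustrative only;
// `node_map` and `initial_nodes` are assumed to come from the surrounding
// scheduler, and are not part of this interface):
//
//   TF_RETURN_IF_ERROR(manager->Init(&node_map));
//   for (const NodeDef* node : initial_nodes) manager->AddNode(node);
//   while (!manager->Empty()) {
//     const NodeDef* node = manager->GetCurrNode();
//     // ... simulate execution of `node`; AddNode() any newly ready nodes ...
//     manager->RemoveCurrNode();
//   }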
class ReadyNodeManager {
 public:
  ReadyNodeManager() {}
  virtual ~ReadyNodeManager() {}
  virtual Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) {
    return Status::OK();
  }
  virtual void AddNode(const NodeDef* node) = 0;
  virtual const NodeDef* GetCurrNode() = 0;
  virtual void RemoveCurrNode() = 0;
  virtual bool Empty() const = 0;
};

class FIFOManager : public ReadyNodeManager {
 public:
  FIFOManager() : ReadyNodeManager() {}
  ~FIFOManager() override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() override {
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    return nodes_.front();
  }
  void RemoveCurrNode() override { nodes_.pop_front(); }
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
};

// The LIFOManager schedules nodes by returning the last one added to the
// scheduler. A node is executed and then its ready outputs are newly added to
// the scheduler, so the LIFOManager will return outputs to a node following
// that node's execution.
class LIFOManager : public ReadyNodeManager {
 public:
  LIFOManager() : ReadyNodeManager() {}
  ~LIFOManager() override {}
  void AddNode(const NodeDef* node) override;
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
  // Keep track of the current node being executed by saving its position.
  // Necessary because nodes may be added to the end of the list while a node
  // is executing, and we want to remove the correct node (the one that is
  // executing) rather than the new ones being added.
  std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
};
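// For example (an illustrative sketch; the order in which newly ready nodes
// are added is assumed here): if executing node A makes its consumers B and C
// ready, added in the order B then C, a FIFOManager that already held ready
// node D schedules D, B, C, whereas a LIFOManager schedules C, B, D. The LIFO
// order tends to consume a producer's freshly computed outputs sooner, which
// is the producer-consumer locality mentioned for CompositeNodeManager below.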
// Abstract class that maintains a heap/priority queue for scheduling ready
// nodes. A derived class needs to implement the Greater() function, which
// returns the comparator for the heap.
class HeapReadyManager : public ReadyNodeManager {
 public:
  HeapReadyManager();
  Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
  ~HeapReadyManager() override {}
  void AddNode(const NodeDef* node) override;
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 protected:
  virtual std::function<bool(const NodeDef*, const NodeDef*)> Greater() = 0;

  // nodes_ is the main queue, where we construct the heap; the front is the
  // current node.
  std::vector<const NodeDef*> nodes_;

  // Comparator functor for the heap; the STL heap is a max-heap, so we use a
  // "greater than" functor to keep the node with the smallest time_ready at
  // the front of the heap.
  std::function<bool(const NodeDef*, const NodeDef*)> greater_;

  // NodeState structure from SchedulerState to get time_ready of ready nodes.
  // Not owned by HeapReadyManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_map_;

  // Cached current node. Set back to nullptr by RemoveCurrNode().
  const NodeDef* curr_node_;
};

// FirstReadyManager picks the node with the minimum time_ready value.
// Behavior is deterministic even when multiple nodes share the minimum
// time_ready value; unique node names are used as the tie-breaker.
class FirstReadyManager : public HeapReadyManager {
 public:
  FirstReadyManager() : HeapReadyManager() {}
  ~FirstReadyManager() override {}

 protected:
  std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;
};

// PriorityReadyManager uses the given node priorities when picking the next
// node from all the ready nodes.
class PriorityReadyManager : public HeapReadyManager {
 public:
  PriorityReadyManager() : HeapReadyManager() {}
  ~PriorityReadyManager() override {}
  void AddNode(const NodeDef* node) override;

  // Note this should be called after Init().
  Status SetPriority(const std::unordered_map<string, int>& node_priority);

 protected:
  std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;

 private:
  // A map from unique node name to priority. Lower number means higher
  // priority.
  std::unordered_map<string, int> node_priority_;
};

// CompositeNodeManager combines several internal NodeManagers: a per-device
// LIFO for normal ops (neither _Send nor _Recv) and FirstReadyManagers for
// _Send ops and _Recv ops; it then chooses the FirstReady node among the ops
// chosen by each internal NodeManager. The objective is to maximize
// producer-consumer locality within a device, while processing nodes across
// devices, including _Send and _Recv, fairly in terms of their time_ready.
class CompositeNodeManager : public ReadyNodeManager {
 public:
  CompositeNodeManager();
  ~CompositeNodeManager() override {}

  Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
  void AddNode(const NodeDef* node) override;
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 private:
  // Internal ready node managers:
  // LIFO for normal ops, to maximize producer-consumer locality.
  // One LIFO per device.
  std::unordered_map<string, LIFOManager> ops_lifo_map_;
  // FirstReady for send and recv. Handling send and recv separately ensures
  // that they do not block previously ready ops in the LIFO schedule.
  FirstReadyManager send_manager_;
  FirstReadyManager recv_manager_;

  // NodeState structure from SchedulerState to get time_ready of ready nodes.
  // Not owned by CompositeNodeManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_map_;

  // Cached current node. Set back to nullptr by RemoveCurrNode().
  const NodeDef* curr_node_;
};

// Constructs a ready node manager from the given string.
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager);
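// For example (an illustrative sketch; the manager-name strings accepted by
// the factory are assumed here):
//
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");
//
// To use node priorities instead (lower number means higher priority):
//
//   PriorityReadyManager priority_manager;
//   TF_RETURN_IF_ERROR(priority_manager.Init(&node_map));
//   TF_RETURN_IF_ERROR(
//       priority_manager.SetPriority({{"node_a", 1}, {"node_b", 2}}));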
// Encapsulates all of the various pieces used to track state of a scheduler;
// enables reuse of all scheduler state-related utilities across different
// scheduler implementations.
class SchedulerState {
 public:
  SchedulerState(const bool use_static_shapes,
                 const bool use_aggressive_shape_inference, Cluster* cluster,
                 std::unique_ptr<VirtualPlacer> placer);
  // Move constructor. Explicitly defined because it otherwise gets implicitly
  // deleted. SchedulerState is a move-only class, as we have a unique_ptr to
  // it in VirtualScheduler. A derivative of VirtualScheduler can move a
  // unique_ptr<SchedulerState> to VirtualScheduler when it is constructed,
  // which is where this move constructor is needed.
  SchedulerState(SchedulerState&& arg) = default;
  // We explicitly delete the assignment and copy operators; this would happen
  // implicitly anyway, but we state it here for clarity.
  SchedulerState& operator=(SchedulerState&& arg) = delete;
  SchedulerState(const SchedulerState&) = delete;
  SchedulerState& operator=(const SchedulerState&) = delete;
  // Destructor. Must be defined such that a derived class can override it
  // and allow proper destruction of the derived class. If this is not done
  // properly, memory leaks can occur.
  virtual ~SchedulerState();
  // Sets up the graph while also performing some necessary transformations.
  // initial_nodes is the set of nodes (primary inputs) discovered by Init(),
  // which may be added by a ReadyNodeManager (or a related/derived scheduler)
  // to begin node scheduling and graph simulation.
  Status Init(const GrapplerItem* item,
              std::vector<const NodeDef*>* initial_nodes,
              bool create_explicit_channel_device = true);

  virtual Costs Summary() const;
  // Like the above, but writes detailed stats to RunMetadata.
  // If metadata is nullptr, then it just calls and returns Summary().
  virtual Costs Summary(RunMetadata* metadata);
  // Generates RunMetadata's step_stats and partition_graphs fields from the
  // results of the virtual execution of the graph.
  // TODO(rdegruijl) See if we can make this function and its caller Summary()
  // const.
  void GenerateRunMetadata(RunMetadata* metadata);

  // Returns per-device memory usage.
  const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
  const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
  void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
  // Returns (read only) device and node states.
  const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
    return &device_;
  }

  const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
    return &node_map_;
  }

  OpContext CreateOpContext(const NodeDef* node) const;
  std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
                                               const Costs& node_costs,
                                               const OpContext& op_context);

  // Some getter functions.
  const GrapplerItem* GetGrapplerItem() { return grappler_item_; }
  Costs GetGraphCost() { return graph_costs_; }
  Cluster* GetCluster() { return cluster_; }
  bool GetUseStaticShape() { return use_static_shapes_; }
  bool GetUseAggressiveShapeInference() {
    return use_aggressive_shape_inference_;
  }
  const std::unordered_map<const NodeDef*, NodeState>& GetNodeMap() {
    return node_map_;
  }

 protected:
  // Assigns the time_scheduled in the NodeState of node to the current
  // execution_time of the device executing this node.
  void SetNodeStateTimeScheduled(const NodeDef* node);

  // This method can be used by a class derived from SchedulerState to
  // access the device state map.
  std::unordered_map<string, DeviceState>* GetMutableDeviceState() {
    return &device_;
  }

 private:
  // Methods called from Init(). Fails if initialized_ is set.

  void MaybeUpdateInputOutput(const NodeDef* node);
  NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
  // Creates a _Send and _Recv pair between `from` and `to`. The argument
  // create_channel_device tells the function to create an explicit device for
  // the channel.
  std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
      const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
      const string& input_name, bool create_channel_device);
  string DeviceName(const NodeDef* node) const;
  string SanitizedDeviceName(const NodeDef* node) const;
  string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;

  // Helper methods.
  void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
                      std::vector<const NodeDef*>* output_nodes);

  std::unordered_map<const NodeDef*, NodeState> node_map_;
  std::unordered_map<string, DeviceState> device_;

  // Pool of NodeDefs for the SendRecv and Identity ops created.
  std::vector<std::unique_ptr<NodeDef>> additional_nodes_;

  // Stats:
  // Op counts, keyed by op name and input shapes.
  // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
  std::map<string, int> op_counts_;
  // Individual op costs, keyed by op name and input shapes.
  // The integer field is execution time in microseconds; the boolean field
  // indicates whether the cost is accurate.
  std::map<string, std::pair<int, bool>> op_costs_;

  Costs graph_costs_;                   // Graph cost.
  std::map<string, Costs> op_to_cost_;  // Per-op cost.

  // Auxiliary data structures for constructing NodeState and DeviceState.
  std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
  Cluster* cluster_;                   // Not owned.
  const GrapplerItem* grappler_item_;  // Not owned.
  bool use_static_shapes_;
  bool initialized_;
  bool track_mem_usage_snapshot_;
  const bool use_aggressive_shape_inference_;
  std::unique_ptr<VirtualPlacer> placer_;
};
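// A minimal interaction sketch (illustrative; `cluster`, `placer`, `item`,
// and the cost estimation step `EstimateCosts` are assumed to be provided by
// the caller, and are not part of this interface):
//
//   SchedulerState state(/*use_static_shapes=*/true,
//                        /*use_aggressive_shape_inference=*/true, cluster,
//                        std::move(placer));
//   std::vector<const NodeDef*> initial_nodes;
//   TF_RETURN_IF_ERROR(state.Init(&item, &initial_nodes));
//   for (const NodeDef* node : initial_nodes) {
//     OpContext op_context = state.CreateOpContext(node);
//     Costs costs = EstimateCosts(op_context);  // assumed helper
//     std::vector<const NodeDef*> newly_ready =
//         state.MarkNodeExecuted(node, costs, op_context);
//   }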
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, devices, etc.
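//
// Typical usage (an illustrative sketch; `EstimateCosts` is an assumed helper
// standing in for the caller's cost estimation):
//
//   VirtualScheduler scheduler(/*use_static_shapes=*/true,
//                              /*use_aggressive_shape_inference=*/true,
//                              cluster, ready_nodes, std::move(placer));
//   TF_CHECK_OK(scheduler.Init(&item));
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = EstimateCosts(op_context);  // assumed helper
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }
//   Costs total_costs = scheduler.Summary();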
class VirtualScheduler {
 public:
  // Does not take ownership of cluster or ready_nodes.
  VirtualScheduler(const bool use_static_shapes,
                   const bool use_aggressive_shape_inference, Cluster* cluster,
                   ReadyNodeManager* ready_nodes,
                   std::unique_ptr<VirtualPlacer> placer);
  // This constructor can be called by a derivative of VirtualScheduler to
  // construct the base class. It lets VirtualScheduler take ownership of
  // a new SchedulerState or a derivative thereof.
  // Note that this constructor does not set a VirtualPlacer; here the
  // VirtualPlacer is passed as a member of the SchedulerState that is passed
  // as an argument.
  VirtualScheduler(ReadyNodeManager* ready_nodes,
                   std::unique_ptr<SchedulerState> scheduler_state);
  virtual ~VirtualScheduler();

  // Initializes the scheduler for the specific grappler item.
  // Should be called immediately after the c'tor or when the scheduler will be
  // reused for a new grappler item. All internal states of the scheduler
  // related to the previous grappler item will be reset/cleared.
  //
  // This function should be called at least once after the scheduler is
  // constructed. An uninitialized or failed-to-initialize scheduler will cause
  // undefined behavior.
  virtual Status Init(const GrapplerItem* item);

  // Gets the current scheduled node for execution; the caller of this function
  // can accordingly simulate the execution of the current scheduled node.
  virtual OpContext GetCurrNode();
  // Marks the current scheduled node as executed. Note that we should call
  // this function only after the execution of the node has been simulated;
  // node_costs captures the simulated costs of the node.
  // Returns true if there is any node left to be scheduled.
  virtual bool MarkCurrNodeExecuted(const Costs& node_costs);

  // Prints out a summary of the execution (timing, memory usage, etc.).
  Costs Summary() const { return scheduler_state_->Summary(); }
  // Like the above, but writes detailed stats to RunMetadata.
  // If metadata is nullptr, then it just calls and returns Summary().
  Costs Summary(RunMetadata* metadata) {
    return scheduler_state_->Summary(metadata);
  }
  // Generates RunMetadata's step_stats and partition_graphs fields from the
  // results of the virtual execution of the graph.
  void GenerateRunMetadata(RunMetadata* metadata) {
    scheduler_state_->GenerateRunMetadata(metadata);
  }
  // Returns per-device memory usage.
  const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
    return scheduler_state_->GetPeakMemoryUsage();
  }
  const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
    return scheduler_state_->GetPersistentMemoryUsage();
  }
  // Returns VirtualScheduler (read only) device and node states.
  const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
    return scheduler_state_->GetDeviceStates();
  }
  const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
    return scheduler_state_->GetNodeStates();
  }
  void enable_mem_usage_tracking() {
    scheduler_state_->enable_mem_usage_tracking();
  }

 protected:
  // The state of the scheduler and the execution of the graph are encapsulated
  // by the scheduler_state_ object.
  std::unique_ptr<SchedulerState> scheduler_state_;
  // ready_nodes_ is responsible for ordering the traversal of the graph.
  ReadyNodeManager* ready_nodes_;  // Not owned.
};

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_