1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 18 19 #include <list> 20 #include <memory> 21 #include <unordered_map> 22 #include <unordered_set> 23 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/step_stats.pb.h" 26 #include "tensorflow/core/grappler/costs/cost_estimator.h" 27 #include "tensorflow/core/grappler/costs/graph_properties.h" 28 #include "tensorflow/core/grappler/costs/op_context.h" 29 #include "tensorflow/core/grappler/costs/virtual_placer.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 35 struct NodeState { 36 // A node (i.e., an op) takes a set of input:port pairs and produces 37 // a set of output ports. 38 39 // Cross references to input and output nodes from graphdef. 40 std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs. 41 // List of output nodes (a list of nodes that takes this output port as input) 42 // keyed by port_num. Note that port_num -1 is used for control dependency. 43 std::unordered_map<int, std::vector<const NodeDef*>> outputs; 44 45 // Info from GraphProperties. 46 std::vector<OpInfo::TensorProperties> input_properties; 47 std::vector<OpInfo::TensorProperties> output_properties; 48 49 // Canonical device name used within VirtualScheduler. 50 string device_name; 51 52 // States updated as scheduling nodes. 53 int num_inputs_ready; 54 std::unordered_map<int, int> num_outputs_executed; 55 Costs::Duration time_ready; 56 Costs::Duration time_scheduled; 57 Costs::Duration time_finished; 58 // Time that all the consumers are executed (hence, no need to keep this 59 // output in memory), keyed by port_num. 60 std::unordered_map<int, Costs::Duration> time_no_references; 61 62 // Note that a node may have multiple output ports. The length of outputs, 63 // num_outputs_executed, and time_no_references should be 64 // identical when a NodeState is fully initialized. 65 // They should be 1 + output_properties.size() as we add [-1] for control 66 // dependency. 67 68 // Node will be ready to be executed at time_ready, scheduled at 69 // time_scheduled, and finishes execution at time_finished. 70 // Each output port uses up memory space from time_scheduled to its 71 // time_no_references. 72 73 // How many times this node has been executed, e.g. in a while loop. 74 int execution_count; 75 NodeStateNodeState76 NodeState() { 77 num_inputs_ready = 0; 78 time_ready = Costs::Duration::max(); 79 time_scheduled = Costs::Duration::max(); 80 time_finished = Costs::Duration::max(); 81 execution_count = 0; 82 // Note that num_outputs_executed and time_no_references are not initialized 83 // here, since we don't know the size (i.e., # outputs for this node). 84 } 85 }; 86 87 struct DeviceState { 88 // Nodes executed on this device in execution order. 89 std::vector<const NodeDef*> nodes_executed; 90 91 struct NodePairHash { 92 public: operatorDeviceState::NodePairHash93 const std::size_t operator()( 94 const std::pair<const NodeDef*, int>& element) const { 95 return std::hash<const NodeDef*>()(element.first); 96 } 97 }; 98 99 // Nodes currently allocated in memory: set of NodeDef* and port_num pairs 100 // so that we can track which output of the node is in memory. 101 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 102 nodes_in_memory; 103 104 // Nodes allocated in memory persistently: e.g., Variables. 105 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 106 persistent_nodes; 107 108 // Snapshot of nodes_in_memory, when memory usage is at peak. 109 // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. 110 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 111 mem_usage_snapshot_at_peak; 112 113 Costs device_costs; 114 std::map<string, Costs> op_to_cost; // Per-op cost. 115 116 int64 memory_usage; // Current temporary memory usage 117 int64 max_memory_usage; // Max temporary memory usage 118 DeviceStateDeviceState119 DeviceState() { 120 device_costs = Costs::ZeroCosts(); 121 device_costs.num_ops_total = 0; 122 memory_usage = 0; 123 max_memory_usage = 0; 124 } 125 GetCurrTimeDeviceState126 Costs::Duration GetCurrTime() const { return device_costs.execution_time; } 127 }; 128 129 // ReadyNodeManager (abstract class): 130 // Keeps ready nodes and picks the best one to be scheduled. 131 class ReadyNodeManager { 132 public: ReadyNodeManager()133 ReadyNodeManager() {} ~ReadyNodeManager()134 virtual ~ReadyNodeManager() {} Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)135 virtual void Init( 136 const std::unordered_map<const NodeDef*, NodeState>* node_state) {} 137 virtual void AddNode(const NodeDef* node) = 0; 138 virtual const NodeDef* GetCurrNode() = 0; 139 virtual void RemoveCurrNode() = 0; 140 virtual bool Empty() const = 0; 141 }; 142 143 class FIFOManager : public ReadyNodeManager { 144 public: FIFOManager()145 FIFOManager() : ReadyNodeManager() {} ~FIFOManager()146 ~FIFOManager() override {} Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)147 void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state) 148 override {} AddNode(const NodeDef * node)149 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } GetCurrNode()150 const NodeDef* GetCurrNode() override { 151 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; 152 return nodes_.front(); 153 } RemoveCurrNode()154 void RemoveCurrNode() override { nodes_.pop_front(); } Empty()155 bool Empty() const override { return nodes_.empty(); } 156 157 private: 158 std::list<const NodeDef*> nodes_; 159 }; 160 161 // The LIFOManager schedules nodes by returning the last one added to the 162 // scheduler. A node is executed and then its ready outputs are newly added to 163 // the scheduler, so the LIFOManager will return outputs to a node following 164 // that node's execution. 165 class LIFOManager : public ReadyNodeManager { 166 public: LIFOManager()167 LIFOManager() : ReadyNodeManager() {} ~LIFOManager()168 ~LIFOManager() override {} Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)169 void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state) 170 override {} AddNode(const NodeDef * node)171 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } 172 const NodeDef* GetCurrNode() override; 173 void RemoveCurrNode() override; Empty()174 bool Empty() const override { return nodes_.empty(); } 175 176 private: 177 std::list<const NodeDef*> nodes_; 178 // Keep track of the current node being executed by saving its position. 179 // Necessary because nodes may be added to the end of the list while a node is 180 // executing, and we want to remove the correct node (the one that is 181 // executing) rather than the new ones being added. 182 std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end(); 183 }; 184 185 // FirstReadyManager picks a node with the minimum time_ready value. 186 // Behavior is unknown if there are more than one nodes with the minimum 187 // time_ready value (it depends on C++ STL push_heap and pop_heap). 188 class FirstReadyManager : public ReadyNodeManager { 189 public: 190 FirstReadyManager(); 191 void Init( 192 const std::unordered_map<const NodeDef*, NodeState>* node_state) override; ~FirstReadyManager()193 ~FirstReadyManager() override {} AddNode(const NodeDef * node)194 void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } 195 const NodeDef* GetCurrNode() override; 196 void RemoveCurrNode() override; 197 bool Empty() const override; 198 199 private: 200 // Move all the nodes in the waiting_queue_ to nodes_. 201 void DrainWaitingQueue(); 202 203 // nodes_ is the main queue, where we construct heap, and the front is the 204 // current node. 205 std::vector<const NodeDef*> nodes_; 206 // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(), 207 // which returns the front of the nodes_, always returns the same node, 208 // even if any of new nodes has time_ready smaller than the current node's. 209 std::vector<const NodeDef*> waiting_queue_; 210 // Comparator functor for heap; stl heap is max heap, so we use "greater than" 211 // functor for keeping the smallest time_ready node at the front of heap. 212 std::function<bool(const NodeDef*, const NodeDef*)> greater_; 213 214 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 215 // Not owned by FirstReadyManager. 216 const std::unordered_map<const NodeDef*, NodeState>* node_state_; 217 }; 218 219 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal 220 // ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv 221 // ops, and then it chooses FirstReady among the ops chosen from each 222 // internal NodeManagers. The objective is to maximize producer-consumer 223 // locality within device, while processing nodes across devices, including 224 // _Send and _Recv, fairly, in terms of their time_ready. 225 class CompositeNodeManager : public ReadyNodeManager { 226 public: 227 CompositeNodeManager(); ~CompositeNodeManager()228 ~CompositeNodeManager() override {} 229 230 void Init( 231 const std::unordered_map<const NodeDef*, NodeState>* node_state) override; 232 void AddNode(const NodeDef* node) override; 233 const NodeDef* GetCurrNode() override; 234 void RemoveCurrNode() override; 235 bool Empty() const override; 236 237 private: 238 // Internal ready node managers: 239 // LIFO for normal ops to maximize producer consumer locality. 240 // One LIFO per device. 241 std::unordered_map<string, LIFOManager> ops_lifo_map_; 242 // FirstReady for send and recv. Handle send and recv separately ensures that 243 // send and recv do not block previously read ops with LIFO schedule. 244 FirstReadyManager send_manager_; 245 FirstReadyManager recv_manager_; 246 247 // NodeState structure from VirtualScheduler to get time_ready of ready nodes. 248 // Not owned by FirstReadyManager. 249 const std::unordered_map<const NodeDef*, NodeState>* node_state_; 250 251 // Cached curr node. Set back to nullptr from RemoveCurrNode(). 252 const NodeDef* curr_node_; 253 }; 254 255 // Constructs a ready node manager from the given string. 256 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory( 257 const string& ready_node_manager); 258 259 // The virtual scheduler emulates execution of nodes in a graph, considering 260 // dependencies, device, etc. 261 class VirtualScheduler { 262 public: 263 // Does not take ownership of cluster or ready_nodes. 264 VirtualScheduler(const bool use_static_shapes, 265 const bool use_aggressive_shape_inference, Cluster* cluster, 266 ReadyNodeManager* ready_nodes); 267 // Initializes the scheduler for the specific grappler item. 268 // Should be called immediately after the c'tor or when the scheduler will be 269 // reused for a new grappler item. All internal states of the scheduler 270 // related to the previous grappler item will be reset/cleared. 271 // 272 // This function should be called at least once after the scheduler is 273 // constructed. An uninitialized or failed-to-initialize scheduler will cause 274 // undefined behavior. 275 Status Init(const GrapplerItem* item); 276 277 OpContext GetCurrNode() const; 278 279 // Returns true if there is any node to be scheduled. 280 bool MarkCurrNodeExecuted(const Costs& node_costs); 281 282 // Prints out summary of execution (timing, memory usage, etc.) 283 Costs Summary() const; 284 // Like the above, but writes detailed stats to RunMetadata. 285 // If metadata is nullptr, then just calls and return Summary(). 286 Costs Summary(RunMetadata* metadata); 287 // Generate RunMetadata's step_stats and partition_graphs fields from results 288 // of the virtual execution of the graph. 289 void GenerateRunMetadata(RunMetadata* metadata); 290 291 // Return per device peak memory usage. 292 const std::unordered_map<string, int64> GetPeakMemoryUsage() const; 293 GetDeviceStates()294 const std::unordered_map<string, DeviceState>* GetDeviceStates() const { 295 return &device_; 296 } GetNodeStates()297 const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { 298 return &node_map_; 299 } 300 enable_mem_usage_tracking()301 void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; } 302 303 private: 304 // Constants. 305 const string kAttrInputSrc = "input_source_"; 306 const string kAttrSrcDevice = "send_device"; 307 const string kAttrDstDevice = "recv_device"; 308 const string kAttrTensorName = "tensor_name"; 309 const string kChannelDevice = "Channel"; 310 311 // Methods called from Init(). Fails if initialize_ is set. 312 void MaybeUpdateInputOutput(const NodeDef* node); 313 NodeState& GetNodeStateOrCreateIt(const NodeDef* node); 314 std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( 315 const NodeDef* from, const NodeDef* to, const NodeDef* input_node, 316 const string& input_name); 317 string DeviceName(const NodeDef* node) const; 318 string SanitizedDeviceName(const NodeDef* node) const; 319 string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; 320 321 // Helper methods. 322 Costs& FindOrCreateZero(const string& op_name, 323 std::map<string, Costs>* op_cost); 324 float Round2(const float x) const; 325 bool IsPersistentNode(const NodeDef* node) const; 326 void AddOutputNodesToReadyQueue(const NodeDef* node, 327 const Costs::Duration& curr_time); 328 329 // Scheduler states: 330 ReadyNodeManager* ready_nodes_; // Not owned. 331 std::unordered_map<const NodeDef*, NodeState> node_map_; 332 std::unordered_map<string, DeviceState> device_; 333 334 // Pool of NodeDefs for SendRecv and Identity ops created. 335 std::vector<std::unique_ptr<NodeDef>> additional_nodes_; 336 337 // Stats: 338 // Op counts with key with input shape. 339 // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]" 340 std::map<string, int> op_counts_; 341 // Individual op costs with key with input shape. 342 // Integer field for execution time in micro seconds. 343 // Boolean field for whether the cost is accurate. 344 std::map<string, std::pair<int, bool>> op_costs_; 345 346 Costs graph_costs_; // Graph cost. 347 std::map<string, Costs> op_to_cost_; // Per-op cost. 348 349 // Auxiliary data structures for constructing NodeState and DeviceState. 350 std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init(). 351 Cluster* cluster_; // Not owned. 352 353 const GrapplerItem* grappler_item_; // Not owned. 354 bool use_static_shapes_; 355 bool initialized_; 356 bool track_mem_usage_snapshot_; 357 const bool use_aggressive_shape_inference_; 358 359 VirtualPlacer placer_; // owned. 360 }; 361 362 } // namespace grappler 363 } // end namespace tensorflow 364 365 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 366