• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
18 
19 #include <list>
20 #include <memory>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/step_stats.pb.h"
26 #include "tensorflow/core/grappler/costs/cost_estimator.h"
27 #include "tensorflow/core/grappler/costs/graph_properties.h"
28 #include "tensorflow/core/grappler/costs/op_context.h"
29 #include "tensorflow/core/grappler/costs/virtual_placer.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 
35 struct NodeState {
36   // A node (i.e., an op) takes a set of input:port pairs and produces
37   // a set of output ports.
38 
39   // Cross references to input and output nodes from graphdef.
40   std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
41   // List of output nodes (a list of nodes that takes this output port as input)
42   // keyed by port_num. Note that port_num -1 is used for control dependency.
43   std::unordered_map<int, std::vector<const NodeDef*>> outputs;
44 
45   // Info from GraphProperties.
46   std::vector<OpInfo::TensorProperties> input_properties;
47   std::vector<OpInfo::TensorProperties> output_properties;
48 
49   // Canonical device name used within VirtualScheduler.
50   string device_name;
51 
52   // States updated as scheduling nodes.
53   int num_inputs_ready;
54   std::unordered_map<int, int> num_outputs_executed;
55   Costs::Duration time_ready;
56   Costs::Duration time_scheduled;
57   Costs::Duration time_finished;
58   // Time that all the consumers are executed (hence, no need to keep this
59   // output in memory), keyed by port_num.
60   std::unordered_map<int, Costs::Duration> time_no_references;
61 
62   // Note that a node may have multiple output ports. The length of outputs,
63   // num_outputs_executed, and time_no_references should be
64   // identical when a NodeState is fully initialized.
65   // They should be 1 + output_properties.size() as we add [-1] for control
66   // dependency.
67 
68   // Node will be ready to be executed at time_ready, scheduled at
69   // time_scheduled, and finishes execution at time_finished.
70   // Each output port uses up memory space from time_scheduled to its
71   // time_no_references.
72 
73   // How many times this node has been executed, e.g. in a while loop.
74   int execution_count;
75 
NodeStateNodeState76   NodeState() {
77     num_inputs_ready = 0;
78     time_ready = Costs::Duration::max();
79     time_scheduled = Costs::Duration::max();
80     time_finished = Costs::Duration::max();
81     execution_count = 0;
82     // Note that num_outputs_executed and time_no_references are not initialized
83     // here, since we don't know the size (i.e., # outputs for this node).
84   }
85 };
86 
87 struct DeviceState {
88   // Nodes executed on this device in execution order.
89   std::vector<const NodeDef*> nodes_executed;
90 
91   struct NodePairHash {
92    public:
operatorDeviceState::NodePairHash93     const std::size_t operator()(
94         const std::pair<const NodeDef*, int>& element) const {
95       return std::hash<const NodeDef*>()(element.first);
96     }
97   };
98 
99   // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
100   // so that we can track which output of the node is in memory.
101   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
102       nodes_in_memory;
103 
104   // Nodes allocated in memory persistently: e.g., Variables.
105   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
106       persistent_nodes;
107 
108   // Snapshot of nodes_in_memory, when memory usage is at peak.
109   // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
110   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
111       mem_usage_snapshot_at_peak;
112 
113   Costs device_costs;
114   std::map<string, Costs> op_to_cost;  // Per-op cost.
115 
116   int64 memory_usage;      // Current temporary memory usage
117   int64 max_memory_usage;  // Max temporary memory usage
118 
DeviceStateDeviceState119   DeviceState() {
120     device_costs = Costs::ZeroCosts();
121     device_costs.num_ops_total = 0;
122     memory_usage = 0;
123     max_memory_usage = 0;
124   }
125 
GetCurrTimeDeviceState126   Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
127 };
128 
129 // ReadyNodeManager (abstract class):
130 // Keeps ready nodes and picks the best one to be scheduled.
131 class ReadyNodeManager {
132  public:
ReadyNodeManager()133   ReadyNodeManager() {}
~ReadyNodeManager()134   virtual ~ReadyNodeManager() {}
Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)135   virtual void Init(
136       const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
137   virtual void AddNode(const NodeDef* node) = 0;
138   virtual const NodeDef* GetCurrNode() = 0;
139   virtual void RemoveCurrNode() = 0;
140   virtual bool Empty() const = 0;
141 };
142 
143 class FIFOManager : public ReadyNodeManager {
144  public:
FIFOManager()145   FIFOManager() : ReadyNodeManager() {}
~FIFOManager()146   ~FIFOManager() override {}
Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)147   void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
148       override {}
AddNode(const NodeDef * node)149   void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
GetCurrNode()150   const NodeDef* GetCurrNode() override {
151     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
152     return nodes_.front();
153   }
RemoveCurrNode()154   void RemoveCurrNode() override { nodes_.pop_front(); }
Empty()155   bool Empty() const override { return nodes_.empty(); }
156 
157  private:
158   std::list<const NodeDef*> nodes_;
159 };
160 
161 // The LIFOManager schedules nodes by returning the last one added to the
162 // scheduler. A node is executed and then its ready outputs are newly added to
163 // the scheduler, so the LIFOManager will return outputs to a node following
164 // that node's execution.
165 class LIFOManager : public ReadyNodeManager {
166  public:
LIFOManager()167   LIFOManager() : ReadyNodeManager() {}
~LIFOManager()168   ~LIFOManager() override {}
Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)169   void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
170       override {}
AddNode(const NodeDef * node)171   void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
172   const NodeDef* GetCurrNode() override;
173   void RemoveCurrNode() override;
Empty()174   bool Empty() const override { return nodes_.empty(); }
175 
176  private:
177   std::list<const NodeDef*> nodes_;
178   // Keep track of the current node being executed by saving its position.
179   // Necessary because nodes may be added to the end of the list while a node is
180   // executing, and we want to remove the correct node (the one that is
181   // executing) rather than the new ones being added.
182   std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
183 };
184 
185 // FirstReadyManager picks a node with the minimum time_ready value.
186 // Behavior is unknown if there are more than one nodes with the minimum
187 // time_ready value (it depends on C++ STL push_heap and pop_heap).
188 class FirstReadyManager : public ReadyNodeManager {
189  public:
190   FirstReadyManager();
191   void Init(
192       const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
~FirstReadyManager()193   ~FirstReadyManager() override {}
AddNode(const NodeDef * node)194   void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
195   const NodeDef* GetCurrNode() override;
196   void RemoveCurrNode() override;
197   bool Empty() const override;
198 
199  private:
200   // Move all the nodes in the waiting_queue_ to nodes_.
201   void DrainWaitingQueue();
202 
203   // nodes_ is the main queue, where we construct heap, and the front is the
204   // current node.
205   std::vector<const NodeDef*> nodes_;
206   // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(),
207   // which returns the front of the nodes_, always returns the same node,
208   // even if any of new nodes has time_ready smaller than the current node's.
209   std::vector<const NodeDef*> waiting_queue_;
210   // Comparator functor for heap; stl heap is max heap, so we use "greater than"
211   // functor for keeping the smallest time_ready node at the front of heap.
212   std::function<bool(const NodeDef*, const NodeDef*)> greater_;
213 
214   // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
215   // Not owned by FirstReadyManager.
216   const std::unordered_map<const NodeDef*, NodeState>* node_state_;
217 };
218 
219 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal
220 // ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv
221 // ops, and then it chooses FirstReady among the ops chosen from each
222 // internal NodeManagers. The objective is to maximize producer-consumer
223 // locality within device, while processing nodes across devices, including
224 // _Send and _Recv, fairly, in terms of their time_ready.
225 class CompositeNodeManager : public ReadyNodeManager {
226  public:
227   CompositeNodeManager();
~CompositeNodeManager()228   ~CompositeNodeManager() override {}
229 
230   void Init(
231       const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
232   void AddNode(const NodeDef* node) override;
233   const NodeDef* GetCurrNode() override;
234   void RemoveCurrNode() override;
235   bool Empty() const override;
236 
237  private:
238   // Internal ready node managers:
239   // LIFO for normal ops to maximize producer consumer locality.
240   // One LIFO per device.
241   std::unordered_map<string, LIFOManager> ops_lifo_map_;
242   // FirstReady for send and recv. Handle send and recv separately ensures that
243   // send and recv do not block previously read ops with LIFO schedule.
244   FirstReadyManager send_manager_;
245   FirstReadyManager recv_manager_;
246 
247   // NodeState structure from VirtualScheduler to get time_ready of ready nodes.
248   // Not owned by FirstReadyManager.
249   const std::unordered_map<const NodeDef*, NodeState>* node_state_;
250 
251   // Cached curr node. Set back to nullptr from RemoveCurrNode().
252   const NodeDef* curr_node_;
253 };
254 
255 // Constructs a ready node manager from the given string.
256 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
257     const string& ready_node_manager);
258 
259 // The virtual scheduler emulates execution of nodes in a graph, considering
260 // dependencies, device, etc.
261 class VirtualScheduler {
262  public:
263   // Does not take ownership of cluster or ready_nodes.
264   VirtualScheduler(const bool use_static_shapes,
265                    const bool use_aggressive_shape_inference, Cluster* cluster,
266                    ReadyNodeManager* ready_nodes);
267   // Initializes the scheduler for the specific grappler item.
268   // Should be called immediately after the c'tor or when the scheduler will be
269   // reused for a new grappler item. All internal states of the scheduler
270   // related to the previous grappler item will be reset/cleared.
271   //
272   // This function should be called at least once after the scheduler is
273   // constructed. An uninitialized or failed-to-initialize scheduler will cause
274   // undefined behavior.
275   Status Init(const GrapplerItem* item);
276 
277   OpContext GetCurrNode() const;
278 
279   // Returns true if there is any node to be scheduled.
280   bool MarkCurrNodeExecuted(const Costs& node_costs);
281 
282   // Prints out summary of execution (timing, memory usage, etc.)
283   Costs Summary() const;
284   // Like the above, but writes detailed stats to RunMetadata.
285   // If metadata is nullptr, then just calls and return Summary().
286   Costs Summary(RunMetadata* metadata);
287   // Generate RunMetadata's step_stats and partition_graphs fields from results
288   // of the virtual execution of the graph.
289   void GenerateRunMetadata(RunMetadata* metadata);
290 
291   // Return per device peak memory usage.
292   const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
293 
GetDeviceStates()294   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
295     return &device_;
296   }
GetNodeStates()297   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
298     return &node_map_;
299   }
300 
enable_mem_usage_tracking()301   void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
302 
303  private:
304   // Constants.
305   const string kAttrInputSrc = "input_source_";
306   const string kAttrSrcDevice = "send_device";
307   const string kAttrDstDevice = "recv_device";
308   const string kAttrTensorName = "tensor_name";
309   const string kChannelDevice = "Channel";
310 
311   // Methods called from Init(). Fails if initialize_ is set.
312   void MaybeUpdateInputOutput(const NodeDef* node);
313   NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
314   std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
315       const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
316       const string& input_name);
317   string DeviceName(const NodeDef* node) const;
318   string SanitizedDeviceName(const NodeDef* node) const;
319   string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
320 
321   // Helper methods.
322   Costs& FindOrCreateZero(const string& op_name,
323                           std::map<string, Costs>* op_cost);
324   float Round2(const float x) const;
325   bool IsPersistentNode(const NodeDef* node) const;
326   void AddOutputNodesToReadyQueue(const NodeDef* node,
327                                   const Costs::Duration& curr_time);
328 
329   // Scheduler states:
330   ReadyNodeManager* ready_nodes_;  // Not owned.
331   std::unordered_map<const NodeDef*, NodeState> node_map_;
332   std::unordered_map<string, DeviceState> device_;
333 
334   // Pool of NodeDefs for SendRecv and Identity ops created.
335   std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
336 
337   // Stats:
338   // Op counts with key with input shape.
339   // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
340   std::map<string, int> op_counts_;
341   // Individual op costs with key with input shape.
342   // Integer field for execution time in micro seconds.
343   // Boolean field for whether the cost is accurate.
344   std::map<string, std::pair<int, bool>> op_costs_;
345 
346   Costs graph_costs_;                   // Graph cost.
347   std::map<string, Costs> op_to_cost_;  // Per-op cost.
348 
349   // Auxiliary data structures for constructing NodeState and DeviceState.
350   std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
351   Cluster* cluster_;                                   // Not owned.
352 
353   const GrapplerItem* grappler_item_;  // Not owned.
354   bool use_static_shapes_;
355   bool initialized_;
356   bool track_mem_usage_snapshot_;
357   const bool use_aggressive_shape_inference_;
358 
359   VirtualPlacer placer_;  // owned.
360 };
361 
362 }  // namespace grappler
363 }  // end namespace tensorflow
364 
365 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
366