/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_

#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

struct NodeState {
  // A node (i.e., an op) takes a set of input:port pairs and produces
  // a set of output ports.

  // Cross references to input and output nodes from graphdef.
  std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
  // List of output nodes (a list of nodes that take this output port as input)
  // keyed by port_num. Note that port_num -1 is used for control dependency.
  std::unordered_map<int, std::vector<const NodeDef*>> outputs;

  // Info from GraphProperties.
  std::vector<OpInfo::TensorProperties> input_properties;
  std::vector<OpInfo::TensorProperties> output_properties;

  // Canonical device name used within VirtualScheduler.
  string device_name;

  // State that is updated as nodes are scheduled.
  int num_inputs_ready;
  std::unordered_map<int, int> num_outputs_executed;
  Costs::Duration time_ready;
  Costs::Duration time_scheduled;
  Costs::Duration time_finished;
  // Time at which all the consumers of an output have executed (hence, there
  // is no need to keep that output in memory), keyed by port_num.
  std::unordered_map<int, Costs::Duration> time_no_references;

  // Note that a node may have multiple output ports. The lengths of outputs,
  // num_outputs_executed, and time_no_references should be identical when a
  // NodeState is fully initialized. They should be 1 + output_properties.size()
  // as we add [-1] for control dependency.

  // A node will be ready to be executed at time_ready, scheduled at
  // time_scheduled, and finishes execution at time_finished.
  // Each output port uses up memory space from time_scheduled to its
  // time_no_references.

  Costs node_costs;  // Node costs per execution.
  Costs TotalNodeCosts() const {
    return MultiplyCosts(node_costs, execution_count);
  }
  // How many times this node has been executed, e.g. in a while loop.
  int execution_count;

  // Output shape incompatible between shape annotation and shape inference.
  bool shape_incompatible;

  NodeState() {
    num_inputs_ready = 0;
    time_ready = Costs::Duration::max();
    time_scheduled = Costs::Duration::max();
    time_finished = Costs::Duration::max();
    execution_count = 0;
    shape_incompatible = false;
    // Note that num_outputs_executed and time_no_references are not
    // initialized here, since we don't know the size (i.e., the number of
    // outputs for this node).
  }
};
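
// For example, for a node with two output ports, outputs,
// num_outputs_executed, and time_no_references are all keyed by {-1, 0, 1},
// where -1 is the control-dependency port. A minimal initialization sketch:
//
//   NodeState state;
//   state.output_properties.resize(2);  // Ports 0 and 1.
//   for (int port = -1; port < 2; ++port) {
//     state.outputs[port];  // Empty consumer list for this port.
//     state.num_outputs_executed[port] = 0;
//     state.time_no_references[port] = Costs::Duration::max();
//   }
//   // Now state.outputs.size() == 1 + state.output_properties.size() == 3.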

struct DeviceState {
  // Nodes executed on this device in execution order.
  std::vector<const NodeDef*> nodes_executed;

  struct NodePairHash {
   public:
    std::size_t operator()(
        const std::pair<const NodeDef*, int>& element) const {
      return std::hash<const NodeDef*>()(element.first);
    }
  };

  // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
  // so that we can track which output of the node is in memory.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      nodes_in_memory;

  // Nodes allocated in memory persistently: e.g., Variables.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      persistent_nodes;

  // Snapshot of nodes_in_memory, taken when memory usage is at its peak.
  // Like nodes_in_memory, it's a set of NodeDef* and port_num pairs.
  std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
      mem_usage_snapshot_at_peak;

  Costs device_costs;
  std::map<string, Costs> op_to_cost;  // Per-op cost.

  int64 memory_usage;      // Current temporary memory usage.
  int64 max_memory_usage;  // Max temporary memory usage.

  // Shape annotation statistics.
  struct ShapeAnnotationStats {
    // Number of ops with shape annotated.
    int64 num_ops_annotated = 0;
    // Number of ops executed multiple times (e.g. in a loop).
    int64 num_ops_executed_more_than_once = 0;
    // Number of ops executed, accounting for execution count.
    int64 num_ops_executed = 0;
    // Number of ops with dynamic shapes (e.g. shape changes in a loop).
    int64 num_ops_with_dynamic_shapes = 0;
    // Number of ops with incompatible shapes between annotation and shape
    // inference.
    int64 num_ops_with_incompatible_shapes = 0;
  } shape_annotation_stats;

  DeviceState() {
    device_costs = Costs::ZeroCosts();
    device_costs.num_ops_total = 0;
    memory_usage = 0;
    max_memory_usage = 0;
  }

  Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
};
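
// A minimal sketch (not the actual scheduler code) of how a DeviceState
// might be updated when a node produces its output on port 0, with
// `output_size` a hypothetical byte count:
//
//   device.nodes_executed.push_back(node);
//   device.nodes_in_memory.insert({node, /*port_num=*/0});
//   device.memory_usage += output_size;
//   if (device.memory_usage > device.max_memory_usage) {
//     device.max_memory_usage = device.memory_usage;
//     device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
//   }
//   // Once the last consumer of {node, 0} runs, the pair is erased from
//   // nodes_in_memory and memory_usage shrinks by output_size again.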

// ReadyNodeManager (abstract class):
// Keeps ready nodes and picks the best one to be scheduled.
class ReadyNodeManager {
 public:
  ReadyNodeManager() {}
  virtual ~ReadyNodeManager() {}
  virtual Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) {
    return Status::OK();
  }
  virtual void AddNode(const NodeDef* node) = 0;
  virtual const NodeDef* GetCurrNode() = 0;
  virtual void RemoveCurrNode() = 0;
  virtual bool Empty() const = 0;
};
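
// All managers share the same driving pattern; a sketch, assuming a concrete
// `manager`, a `node_map` from VirtualScheduler, and some initially ready
// nodes:
//
//   TF_RETURN_IF_ERROR(manager.Init(&node_map));
//   for (const NodeDef* node : initially_ready_nodes) manager.AddNode(node);
//   while (!manager.Empty()) {
//     const NodeDef* curr = manager.GetCurrNode();
//     // ... simulate execution of curr; AddNode() any outputs that become
//     // ready ...
//     manager.RemoveCurrNode();
//   }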

class FIFOManager : public ReadyNodeManager {
 public:
  FIFOManager() : ReadyNodeManager() {}
  ~FIFOManager() override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() override {
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
    return nodes_.front();
  }
  void RemoveCurrNode() override { nodes_.pop_front(); }
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
};

// The LIFOManager schedules nodes by returning the last one added to the
// scheduler. A node is executed and then its ready outputs are newly added to
// the scheduler, so the LIFOManager will return the outputs of a node right
// after that node's execution.
class LIFOManager : public ReadyNodeManager {
 public:
  LIFOManager() : ReadyNodeManager() {}
  ~LIFOManager() override {}
  void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override { return nodes_.empty(); }

 private:
  std::list<const NodeDef*> nodes_;
  // Keep track of the current node being executed by saving its position.
  // Necessary because nodes may be added to the end of the list while a node is
  // executing, and we want to remove the correct node (the one that is
  // executing) rather than the new ones being added.
  std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
};
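
// Why curr_pos_ matters, with hypothetical nodes a, b, c:
//
//   LIFOManager lifo;
//   lifo.AddNode(a);
//   lifo.AddNode(b);
//   const NodeDef* curr = lifo.GetCurrNode();  // b, the last one added.
//   lifo.AddNode(c);         // An output of b becomes ready mid-execution.
//   lifo.RemoveCurrNode();   // Removes b (tracked by curr_pos_), not c.
//   // The next GetCurrNode() returns c.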

// Abstract class that maintains a heap/priority queue for scheduling ready
// nodes. A derived class needs to implement the Greater() function, which
// returns the comparator for the heap.
class HeapReadyManager : public ReadyNodeManager {
 public:
  HeapReadyManager();
  Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
  ~HeapReadyManager() override {}
  void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 protected:
  virtual std::function<bool(const NodeDef*, const NodeDef*)> Greater() = 0;
  // Moves all the nodes in waiting_queue_ to nodes_.
  void DrainWaitingQueue();

  // nodes_ is the main queue, on which we construct the heap; its front is the
  // current node.
  std::vector<const NodeDef*> nodes_;
  // Newly added nodes are added to waiting_queue_. That way, GetCurrNode(),
  // which returns the front of nodes_, always returns the same node, even if
  // any of the new nodes has a time_ready smaller than the current node's.
  std::vector<const NodeDef*> waiting_queue_;
  // Comparator functor for the heap; the STL heap is a max heap, so we use a
  // "greater than" functor to keep the node with the smallest time_ready at
  // the front of the heap.
  std::function<bool(const NodeDef*, const NodeDef*)> greater_;

  // NodeState structure from VirtualScheduler to get the time_ready of ready
  // nodes. Not owned by HeapReadyManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_map_;
};
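
// A sketch of the waiting-queue contract, using FirstReadyManager (declared
// below) with hypothetical nodes a and b:
//
//   FirstReadyManager manager;
//   TF_CHECK_OK(manager.Init(&node_map));
//   manager.AddNode(a);
//   const NodeDef* curr = manager.GetCurrNode();
//   manager.AddNode(b);  // Parked in waiting_queue_; the heap is untouched.
//   // GetCurrNode() still returns curr, even if b has a smaller time_ready;
//   // the waiting queue is only drained into the heap once the current node
//   // is removed.
//   manager.RemoveCurrNode();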

// FirstReadyManager picks the node with the minimum time_ready value.
// Behavior is deterministic when there is more than one node with the minimum
// time_ready value, with unique node names used as the tie-breaker.
class FirstReadyManager : public HeapReadyManager {
 public:
  FirstReadyManager() : HeapReadyManager() {}
  ~FirstReadyManager() override {}

 protected:
  std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;
};

// PriorityReadyManager uses the given node priorities when picking the next
// node from all the ready nodes.
class PriorityReadyManager : public HeapReadyManager {
 public:
  PriorityReadyManager() : HeapReadyManager() {}
  ~PriorityReadyManager() override {}
  void AddNode(const NodeDef* node) override;

  // Note this should be called after Init().
  Status SetPriority(const std::unordered_map<string, int>& node_priority);

 protected:
  std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;

 private:
  // A map from unique node name to priority. A lower number means higher
  // priority.
  std::unordered_map<string, int> node_priority_;
};
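
// Usage sketch (node names and priority values are hypothetical):
//
//   PriorityReadyManager manager;
//   TF_RETURN_IF_ERROR(manager.Init(&node_map));
//   std::unordered_map<string, int> priorities;
//   priorities["conv1"] = 1;  // Scheduled before ...
//   priorities["conv2"] = 2;  // ... this one, when both are ready.
//   TF_RETURN_IF_ERROR(manager.SetPriority(priorities));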

// CompositeNodeManager combines several internal NodeManagers: a per-device
// LIFO for normal ops (neither _Send nor _Recv) and FirstReadyManagers for
// _Send ops and _Recv ops; it then chooses the first-ready op among the ops
// chosen by each internal NodeManager. The objective is to maximize
// producer-consumer locality within a device, while processing nodes across
// devices, including _Send and _Recv, fairly in terms of their time_ready.
class CompositeNodeManager : public ReadyNodeManager {
 public:
  CompositeNodeManager();
  ~CompositeNodeManager() override {}

  Status Init(
      const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
  void AddNode(const NodeDef* node) override;
  const NodeDef* GetCurrNode() override;
  void RemoveCurrNode() override;
  bool Empty() const override;

 private:
  // Internal ready node managers:
  // LIFO for normal ops to maximize producer-consumer locality.
  // One LIFO per device.
  std::unordered_map<string, LIFOManager> ops_lifo_map_;
  // FirstReady for send and recv. Handling send and recv separately ensures
  // that they do not block previously ready ops under the LIFO schedule.
  FirstReadyManager send_manager_;
  FirstReadyManager recv_manager_;

  // NodeState structure from VirtualScheduler to get the time_ready of ready
  // nodes. Not owned by CompositeNodeManager.
  const std::unordered_map<const NodeDef*, NodeState>* node_map_;

  // Cached current node. Set back to nullptr by RemoveCurrNode().
  const NodeDef* curr_node_;
};
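
// A single pick, sketched with hypothetical devices and ops:
//   - ops_lifo_map_["/CPU:0"] proposes its most recently added normal op,
//   - send_manager_ proposes the _Send op with the smallest time_ready,
//   - recv_manager_ proposes the _Recv op with the smallest time_ready,
// and GetCurrNode() returns whichever of these candidates has the smallest
// time_ready.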

// Constructs a ready node manager from the given string.
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager);
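
// A usage sketch; the set of recognized names lives in the .cc file, and
// "FirstReady" here is illustrative:
//
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");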

// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
 public:
  // Does not take ownership of cluster or ready_nodes.
  VirtualScheduler(const bool use_static_shapes,
                   const bool use_aggressive_shape_inference, Cluster* cluster,
                   ReadyNodeManager* ready_nodes,
                   std::unique_ptr<VirtualPlacer> placer);

  // Initializes the scheduler for the specific grappler item.
  // Should be called immediately after the c'tor or when the scheduler will be
  // reused for a new grappler item. All internal states of the scheduler
  // related to the previous grappler item will be reset/cleared.
  //
  // This function should be called at least once after the scheduler is
  // constructed. An uninitialized or failed-to-initialize scheduler will cause
  // undefined behavior.
  Status Init(const GrapplerItem* item);

  OpContext GetCurrNode() const;

  // Returns true if there is any node to be scheduled.
  bool MarkCurrNodeExecuted(const Costs& node_costs);

  // Prints out a summary of the execution (timing, memory usage, etc.).
  Costs Summary() const;
  // Like the above, but writes detailed stats to RunMetadata.
  // If metadata is nullptr, it just calls and returns Summary().
  Costs Summary(RunMetadata* metadata);
  // Generates RunMetadata's step_stats and partition_graphs fields from the
  // results of the virtual execution of the graph.
  void GenerateRunMetadata(RunMetadata* metadata);

  // Returns per-device memory usage.
  const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
  const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;

  // Returns VirtualScheduler (read only) device and node states.
  const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
    return &device_;
  }
  const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
    return &node_map_;
  }

  void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }

 private:
  // Methods called from Init(). They fail if initialized_ is set.
  void MaybeUpdateInputOutput(const NodeDef* node);
  NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
  std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
      const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
      const string& input_name);
  string DeviceName(const NodeDef* node) const;
  string SanitizedDeviceName(const NodeDef* node) const;
  string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;

  // Helper methods.
  void AddOutputNodesToReadyQueue(const NodeDef* node,
                                  const Costs::Duration& curr_time);

  // Scheduler states:
  ReadyNodeManager* ready_nodes_;  // Not owned.
  std::unordered_map<const NodeDef*, NodeState> node_map_;
  std::unordered_map<string, DeviceState> device_;

  // Pool of NodeDefs for SendRecv and Identity ops created.
  std::vector<std::unique_ptr<NodeDef>> additional_nodes_;

  // Stats:
  // Op counts, keyed by op name and input shapes.
  // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
  std::map<string, int> op_counts_;
  // Individual op costs, keyed by op name and input shapes.
  // The integer field is the execution time in microseconds, and the boolean
  // field is whether the cost is accurate.
  std::map<string, std::pair<int, bool>> op_costs_;

  Costs graph_costs_;                   // Graph cost.
  std::map<string, Costs> op_to_cost_;  // Per-op cost.

  // Auxiliary data structures for constructing NodeState and DeviceState.
  std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
  Cluster* cluster_;                                   // Not owned.

  const GrapplerItem* grappler_item_;  // Not owned.
  bool use_static_shapes_;
  bool initialized_;
  bool track_mem_usage_snapshot_;
  const bool use_aggressive_shape_inference_;

  std::unique_ptr<VirtualPlacer> placer_;
};
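
// A minimal end-to-end sketch, assuming `cluster`, a GrapplerItem `item`, a
// placer, and some per-node cost estimate are available:
//
//   FirstReadyManager ready_nodes;
//   std::unique_ptr<VirtualPlacer> placer = ...;  // See virtual_placer.h.
//   VirtualScheduler scheduler(/*use_static_shapes=*/true,
//                              /*use_aggressive_shape_inference=*/true,
//                              cluster, &ready_nodes, std::move(placer));
//   TF_RETURN_IF_ERROR(scheduler.Init(&item));
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = ...;  // e.g., from an op-level cost estimator.
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }
//   Costs total_costs = scheduler.Summary();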

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_