• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
18 
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 
35 ABSL_CONST_INIT extern const char kAttrInputSrc[];
36 ABSL_CONST_INIT extern const char kAttrSrcDevice[];
37 ABSL_CONST_INIT extern const char kAttrDstDevice[];
38 ABSL_CONST_INIT extern const char kAttrTensorName[];
39 ABSL_CONST_INIT extern const char kChannelDevice[];
40 ABSL_CONST_INIT extern const char kStreaming[];
41 
42 struct NodeState {
43   // A node (i.e., an op) takes a set of input:port pairs and produces
44   // a set of output ports.
45 
46   // Cross references to input and output nodes from graphdef.
47   std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
48   // List of output nodes (a list of nodes that takes this output port as input)
49   // keyed by port_num. Note that port_num -1 is used for control dependency.
50   std::unordered_map<int, std::vector<const NodeDef*>> outputs;
51 
52   // Info from GraphProperties.
53   std::vector<OpInfo::TensorProperties> input_properties;
54   std::vector<OpInfo::TensorProperties> output_properties;
55 
56   // Canonical device name used within VirtualScheduler.
57   string device_name;
58 
59   // States updated as scheduling nodes.
60   int num_inputs_ready;
61   std::unordered_map<int, int> num_outputs_executed;
62   Costs::Duration time_ready;
63   Costs::Duration time_scheduled;
64   Costs::Duration time_finished;
65   // Time that all the consumers are executed (hence, no need to keep this
66   // output in memory), keyed by port_num.
67   std::unordered_map<int, Costs::Duration> time_no_references;
68 
69   // Note that a node may have multiple output ports. The length of outputs,
70   // num_outputs_executed, and time_no_references should be
71   // identical when a NodeState is fully initialized.
72   // They should be 1 + output_properties.size() as we add [-1] for control
73   // dependency.
74 
75   // Node will be ready to be executed at time_ready, scheduled at
76   // time_scheduled, and finishes execution at time_finished.
77   // Each output port uses up memory space from time_scheduled to its
78   // time_no_references.
79 
80   Costs node_costs;  // Node costs per execution
TotalNodeCostsNodeState81   Costs TotalNodeCosts() const {
82     return MultiplyCosts(node_costs, execution_count);
83   }
84   // How many times this node has been executed, e.g. in a while loop.
85   int execution_count;
86 
87   // Output shape incompatible between shape annotation and shape inference.
88   bool shape_incompatible;
89 
NodeStateNodeState90   NodeState() {
91     num_inputs_ready = 0;
92     time_ready = Costs::Duration::max();
93     time_scheduled = Costs::Duration::max();
94     time_finished = Costs::Duration::max();
95     execution_count = 0;
96     shape_incompatible = false;
97     // Note that num_outputs_executed and time_no_references are not initialized
98     // here, since we don't know the size (i.e., # outputs for this node).
99   }
100 };
101 
102 struct DeviceState {
103   // Nodes executed on this device in execution order.
104   std::vector<const NodeDef*> nodes_executed;
105 
106   struct NodePairHash {
107    public:
operatorDeviceState::NodePairHash108     const std::size_t operator()(
109         const std::pair<const NodeDef*, int>& element) const {
110       return std::hash<const NodeDef*>()(element.first);
111     }
112   };
113 
114   // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
115   // so that we can track which output of the node is in memory.
116   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
117       nodes_in_memory;
118 
119   // Nodes allocated in memory persistently: e.g., Variables.
120   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
121       persistent_nodes;
122 
123   // Snapshot of nodes_in_memory, when memory usage is at peak.
124   // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
125   std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash>
126       mem_usage_snapshot_at_peak;
127 
128   Costs device_costs;
129   std::map<string, Costs> op_to_cost;  // Per-op cost.
130 
131   int64 memory_usage;      // Current temporary memory usage
132   int64 max_memory_usage;  // Max temporary memory usage
133 
134   // Shape annotation statistics.
135   struct ShapeAnnotationStats {
136     // Number of ops with shape annotated.
137     int64 num_ops_annotated = 0;
138     // Number of ops executed multiple times (e.g. in a loop).
139     int64 num_ops_executed_more_than_once = 0;
140     // Number of ops executed: account for execution count.
141     int64 num_ops_executed = 0;
142     // Number of ops with dynamic shapes (e.g. shape changes in a loop).
143     int64 num_ops_with_dynamic_shapes = 0;
144     // Number of ops with incompatible shapes between annotation and shape
145     // inference.
146     int64 num_ops_with_incompatible_shapes = 0;
147   } shape_annotation_stats;
148 
DeviceStateDeviceState149   DeviceState() {
150     device_costs = Costs::ZeroCosts();
151     device_costs.num_ops_total = 0;
152     memory_usage = 0;
153     max_memory_usage = 0;
154   }
155 
GetCurrTimeDeviceState156   Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
157 };
158 
159 // ReadyNodeManager (abstract class):
160 // Keeps ready nodes and picks the best one to be scheduled.
161 class ReadyNodeManager {
162  public:
ReadyNodeManager()163   ReadyNodeManager() {}
~ReadyNodeManager()164   virtual ~ReadyNodeManager() {}
Init(const std::unordered_map<const NodeDef *,NodeState> * node_map)165   virtual Status Init(
166       const std::unordered_map<const NodeDef*, NodeState>* node_map) {
167     return Status::OK();
168   }
169   virtual void AddNode(const NodeDef* node) = 0;
170   virtual const NodeDef* GetCurrNode() = 0;
171   virtual void RemoveCurrNode() = 0;
172   virtual bool Empty() const = 0;
173 };
174 
175 class FIFOManager : public ReadyNodeManager {
176  public:
FIFOManager()177   FIFOManager() : ReadyNodeManager() {}
~FIFOManager()178   ~FIFOManager() override {}
AddNode(const NodeDef * node)179   void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
GetCurrNode()180   const NodeDef* GetCurrNode() override {
181     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
182     return nodes_.front();
183   }
RemoveCurrNode()184   void RemoveCurrNode() override { nodes_.pop_front(); }
Empty()185   bool Empty() const override { return nodes_.empty(); }
186 
187  private:
188   std::list<const NodeDef*> nodes_;
189 };
190 
191 // The LIFOManager schedules nodes by returning the last one added to the
192 // scheduler. A node is executed and then its ready outputs are newly added to
193 // the scheduler, so the LIFOManager will return outputs to a node following
194 // that node's execution.
195 class LIFOManager : public ReadyNodeManager {
196  public:
LIFOManager()197   LIFOManager() : ReadyNodeManager() {}
~LIFOManager()198   ~LIFOManager() override {}
199   void AddNode(const NodeDef* node) override;
200   const NodeDef* GetCurrNode() override;
201   void RemoveCurrNode() override;
Empty()202   bool Empty() const override { return nodes_.empty(); }
203 
204  private:
205   std::list<const NodeDef*> nodes_;
206   // Keep track of the current node being executed by saving its position.
207   // Necessary because nodes may be added to the end of the list while a node is
208   // executing, and we want to remove the correct node (the one that is
209   // executing) rather than the new ones being added.
210   std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end();
211 };
212 
213 // Abstract class that maintains a heap/priority queue for scheduling ready
214 // nodes. Derived class needs to implement the Greater() function which returns
215 // the comparator for the heap.
216 class HeapReadyManager : public ReadyNodeManager {
217  public:
218   HeapReadyManager();
219   Status Init(
220       const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
~HeapReadyManager()221   ~HeapReadyManager() override {}
222   void AddNode(const NodeDef* node) override;
223   const NodeDef* GetCurrNode() override;
224   void RemoveCurrNode() override;
225   bool Empty() const override;
226 
227  protected:
228   virtual std::function<bool(const NodeDef*, const NodeDef*)> Greater() = 0;
229 
230   // nodes_ is the main queue, where we construct heap, and the front is the
231   // current node.
232   std::vector<const NodeDef*> nodes_;
233 
234   // Comparator functor for heap; stl heap is max heap, so we use "greater than"
235   // functor for keeping the smallest time_ready node at the front of heap.
236   std::function<bool(const NodeDef*, const NodeDef*)> greater_;
237 
238   // NodeState structure from SchedulerState to get time_ready of ready nodes.
239   // Not owned by FirstReadyManager.
240   const std::unordered_map<const NodeDef*, NodeState>* node_map_;
241 
242   // Cached curr node. Set back to nullptr from RemoveCurrNode().
243   const NodeDef* curr_node_;
244 };
245 
246 // FirstReadyManager picks a node with the minimum time_ready value.
247 // Behavior is deterministic when there are more than one nodes with the minimum
248 // time_ready value with unique node names as the tie-breaker.
249 class FirstReadyManager : public HeapReadyManager {
250  public:
FirstReadyManager()251   FirstReadyManager() : HeapReadyManager() {}
~FirstReadyManager()252   ~FirstReadyManager() override {}
253 
254  protected:
255   std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;
256 };
257 
258 // PriorityReadyManager uses the given node priorities when picking up next node
259 // from all the ready nodes.
260 class PriorityReadyManager : public HeapReadyManager {
261  public:
PriorityReadyManager()262   PriorityReadyManager() : HeapReadyManager() {}
~PriorityReadyManager()263   ~PriorityReadyManager() override {}
264   void AddNode(const NodeDef* node) override;
265 
266   // Note this should be called after Init().
267   Status SetPriority(const std::unordered_map<string, int>& node_priority);
268 
269  protected:
270   std::function<bool(const NodeDef*, const NodeDef*)> Greater() override;
271 
272  private:
273   // A map from unique node name to priority. Lower number means higher
274   // priority.
275   std::unordered_map<string, int> node_priority_;
276 };
277 
278 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal
279 // ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv
280 // ops, and then it chooses FirstReady among the ops chosen from each
281 // internal NodeManagers. The objective is to maximize producer-consumer
282 // locality within device, while processing nodes across devices, including
283 // _Send and _Recv, fairly, in terms of their time_ready.
284 class CompositeNodeManager : public ReadyNodeManager {
285  public:
286   CompositeNodeManager();
~CompositeNodeManager()287   ~CompositeNodeManager() override {}
288 
289   Status Init(
290       const std::unordered_map<const NodeDef*, NodeState>* node_map) override;
291   void AddNode(const NodeDef* node) override;
292   const NodeDef* GetCurrNode() override;
293   void RemoveCurrNode() override;
294   bool Empty() const override;
295 
296  private:
297   // Internal ready node managers:
298   // LIFO for normal ops to maximize producer consumer locality.
299   // One LIFO per device.
300   std::unordered_map<string, LIFOManager> ops_lifo_map_;
301   // FirstReady for send and recv. Handle send and recv separately ensures that
302   // send and recv do not block previously read ops with LIFO schedule.
303   FirstReadyManager send_manager_;
304   FirstReadyManager recv_manager_;
305 
306   // NodeState structure from SchedulerState to get time_ready of ready nodes.
307   // Not owned by CompositeReadyManager.
308   const std::unordered_map<const NodeDef*, NodeState>* node_map_;
309 
310   // Cached curr node. Set back to nullptr from RemoveCurrNode().
311   const NodeDef* curr_node_;
312 };
313 
314 // Constructs a ready node manager from the given string.
315 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
316     const string& ready_node_manager);
317 
318 // Encapsulates all of the various pieces uses to track state of a scheduler;
319 // enables reuse of all scheduler state-related utilities across different
320 // scheduler implementations.
321 class SchedulerState {
322  public:
323   SchedulerState(const bool use_static_shapes,
324                  const bool use_aggressive_shape_inference, Cluster* cluster,
325                  std::unique_ptr<VirtualPlacer> placer);
326   // Move constructor. Explicitly defined because it otherwise gets implicitly
327   // deleted. SchedulerState is a move-only class, as we have a <unique_ptr>
328   // for it in VirtualScheduler. A derivative of VirtualScheduler can move a
329   // <unique_ptr> SchedulerState to VirtualScheduler when it is constructed,
330   // which is where this move constructor is needed.
331   SchedulerState(SchedulerState&& arg) = default;
332   // We explicitly delete assinment and copy operators, this is done implicitly,
333   // but we state it here explicitly for clarity.
334   SchedulerState& operator=(SchedulerState&& arg) = delete;
335   SchedulerState(const SchedulerState&) = delete;
336   SchedulerState& operator=(const SchedulerState&) = delete;
337   // Destructor. Must be defined such that a derivative class can override it
338   // and allow proper desctruction of the derivative class. If this is not done
339   // properly, memory leaks can occur.
340   virtual ~SchedulerState();
341   // Sets up the graph while also performing some necessary transformations
342   // initial_nodes is the set of nodes (primary inputs) discovered by Init()
343   // which may be added by a ReadyNodeManager (or related/derivative scheduler)
344   // to begin node schedule and graph simulation.
345   Status Init(const GrapplerItem* item,
346               std::vector<const NodeDef*>* initial_nodes,
347               bool create_explicit_channel_device = true);
348 
349   virtual Costs Summary() const;
350   // Like the above, but writes detailed stats to RunMetadata.
351   // If metadata is nullptr, then just calls and return Summary().
352   virtual Costs Summary(RunMetadata* metadata);
353   // Generates RunMetadata's step_stats and partition_graphs fields from results
354   // of the virtual execution of the graph.
355   // TODO(rdegruijl) See if we can make this function and caller Summary()
356   // const.
357   void GenerateRunMetadata(RunMetadata* metadata);
358 
359   // Returns per device memory usage.
360   const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
361   const std::unordered_map<string, int64> GetPersistentMemoryUsage() const;
enable_mem_usage_tracking()362   void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; }
363   // Returns (read only) device and node states.
GetDeviceStates()364   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
365     return &device_;
366   }
367 
GetNodeStates()368   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
369     return &node_map_;
370   }
371 
372   OpContext CreateOpContext(const NodeDef* node) const;
373   std::vector<const NodeDef*> MarkNodeExecuted(const NodeDef* node,
374                                                const Costs& node_costs,
375                                                const OpContext& op_context);
376 
377   // Some getter functions.
GetGrapplerItem()378   const GrapplerItem* GetGrapplerItem() { return grappler_item_; }
GetGraphCost()379   Costs GetGraphCost() { return graph_costs_; }
GetCluster()380   Cluster* GetCluster() { return cluster_; }
GetUseStaticShape()381   bool GetUseStaticShape() { return use_static_shapes_; }
GetUseAggressiveShapeInference()382   bool GetUseAggressiveShapeInference() {
383     return use_aggressive_shape_inference_;
384   }
GetNodeMap()385   const std::unordered_map<const NodeDef*, NodeState>& GetNodeMap() {
386     return node_map_;
387   }
388 
389  protected:
390   // Assigns the time_scheduled in the NodeState of node to the current
391   // execution_time of the device executing this node.
392   void SetNodeStateTimeScheduled(const NodeDef* node);
393 
394   // This method can be used by a class derived from SchedulerState to
395   // access the device state map.
GetMutableDeviceState()396   std::unordered_map<string, DeviceState>* GetMutableDeviceState() {
397     return &device_;
398   }
399 
400  private:
401   // Methods called from Init(). Fails if initialize_ is set.
402 
403   void MaybeUpdateInputOutput(const NodeDef* node);
404   NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
405   // Creates a Send_ and Recv_ pair between from and to. The argument
406   // create_channel_device tells the function to create an explicit device for
407   // the channel.
408   std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
409       const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
410       const string& input_name, bool create_channel_device);
411   string DeviceName(const NodeDef* node) const;
412   string SanitizedDeviceName(const NodeDef* node) const;
413   string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
414 
415   // Helper methods.
416   void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time,
417                       std::vector<const NodeDef*>* output_nodes);
418 
419   std::unordered_map<const NodeDef*, NodeState> node_map_;
420   std::unordered_map<string, DeviceState> device_;
421 
422   // Pool of NodeDefs for SendRecv and Identity ops created.
423   std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
424 
425   // Stats:
426   // Op counts with key with input shape.
427   // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]"
428   std::map<string, int> op_counts_;
429   // Individual op costs with key with input shape.
430   // Integer field for execution time in micro seconds.
431   // Boolean field for whether the cost is accurate.
432   std::map<string, std::pair<int, bool>> op_costs_;
433 
434   Costs graph_costs_;                   // Graph cost.
435   std::map<string, Costs> op_to_cost_;  // Per-op cost.
436 
437   // Auxiliary data structures for constructing NodeState and DeviceState.
438   std::unique_ptr<GraphProperties> graph_properties_;  // Initialized in Init().
439   Cluster* cluster_;                                   // Not owned.
440   const GrapplerItem* grappler_item_;                  // Not owned.
441   bool use_static_shapes_;
442   bool initialized_;
443   bool track_mem_usage_snapshot_;
444   const bool use_aggressive_shape_inference_;
445   std::unique_ptr<VirtualPlacer> placer_;
446 };
447 
448 // The virtual scheduler emulates execution of nodes in a graph, considering
449 // dependencies, device, etc.
450 class VirtualScheduler {
451  public:
452   // Does not take ownership of cluster or ready_nodes.
453   VirtualScheduler(const bool use_static_shapes,
454                    const bool use_aggressive_shape_inference, Cluster* cluster,
455                    ReadyNodeManager* ready_nodes,
456                    std::unique_ptr<VirtualPlacer> placer);
457   // This constructor can be called by a derivative of VirtualScheduler to
458   // construct the base class. It lets VirtualScheduler take ownership of
459   // a new SchedulerState or a derivative thereof.
460   // Note that this constructor does not set a VirtualPlacer, in this
461   // constructor the VirtialPlacer is passed as a member of the SchedulerState
462   // that is passed as an argument.
463   VirtualScheduler(ReadyNodeManager* ready_nodes,
464                    std::unique_ptr<SchedulerState> scheduler_state);
465   virtual ~VirtualScheduler();
466 
467   // Initializes the scheduler for the specific grappler item.
468   // Should be called immediately after the c'tor or when the scheduler will be
469   // reused for a new grappler item. All internal states of the scheduler
470   // related to the previous grappler item will be reset/cleared.
471   //
472   // This function should be called at least once after the scheduler is
473   // constructed. An uninitialized or failed-to-initialize scheduler will cause
474   // undefined behavior.
475   virtual Status Init(const GrapplerItem* item);
476 
477   // Gets the current scheduled node for execution; the caller of this function
478   // can accordingly simulate the execution of the current scheduled node.
479   virtual OpContext GetCurrNode();
480   // Marks the current scheduled node as executed. Note that we should call this
481   // function only after the execution of the node has been simulated;
482   // node_costs_ capture the simulated costs of the node.
483   // Returns true if there is any node to be scheduled.
484   virtual bool MarkCurrNodeExecuted(const Costs& node_costs);
485 
486   // Prints out summary of execution (timing, memory usage, etc.)
Summary()487   Costs Summary() const { return scheduler_state_->Summary(); }
488   // Like the above, but writes detailed stats to RunMetadata.
489   // If metadata is nullptr, then just calls and return Summary().
Summary(RunMetadata * metadata)490   Costs Summary(RunMetadata* metadata) {
491     return scheduler_state_->Summary(metadata);
492   }
493   // Generates RunMetadata's step_stats and partition_graphs fields from results
494   // of the virtual execution of the graph.
GenerateRunMetadata(RunMetadata * metadata)495   void GenerateRunMetadata(RunMetadata* metadata) {
496     scheduler_state_->GenerateRunMetadata(metadata);
497   }
498   // Returns per device memory usage.
GetPeakMemoryUsage()499   const std::unordered_map<string, int64> GetPeakMemoryUsage() const {
500     return scheduler_state_->GetPeakMemoryUsage();
501   }
GetPersistentMemoryUsage()502   const std::unordered_map<string, int64> GetPersistentMemoryUsage() const {
503     return scheduler_state_->GetPersistentMemoryUsage();
504   }
505   // Returns VirtualScheduler (read only) device and node states.
GetDeviceStates()506   const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
507     return scheduler_state_->GetDeviceStates();
508   }
GetNodeStates()509   const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const {
510     return scheduler_state_->GetNodeStates();
511   }
enable_mem_usage_tracking()512   void enable_mem_usage_tracking() {
513     scheduler_state_->enable_mem_usage_tracking();
514   }
515 
516  protected:
517   // The state of the scheduler and the execution of the graph is encapsulated
518   // by the scheduler_state_ object.
519   std::unique_ptr<SchedulerState> scheduler_state_;
520   // ready_nodes_ is responsible for ordering the traversal of the graph.
521   ReadyNodeManager* ready_nodes_;  // Not owned.
522 };
523 
524 }  // namespace grappler
525 }  // end namespace tensorflow
526 
527 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
528