/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_

#include <deque>
#include <list>
#include <memory>
#include <string>
// TODO(b/114492873): Move this include into core/platform.
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"

namespace tensorflow {
namespace data {
namespace model {

// A constant that can be used to enable auto-tuning.
constexpr int64_t kAutotune = -1;
constexpr char kParallelism[] = "parallelism";
constexpr char kBufferSize[] = "buffer_size";

// A key used to identify the input time of the model.
constexpr char kModelInputTimeKey[] = "model_input_time";

enum class TraversalOrder {
  BFS = 0,
  REVERSE_BFS = 1,
};

// Represents thread-safe state that can be shared between an input pipeline
// and the performance model.
struct SharedState {
 public:
  SharedState(int64_t value, std::shared_ptr<mutex> mu,
              std::shared_ptr<condition_variable> cond_var)
      : value(value),
        mu(std::move(mu)),
        cond_var(std::move(cond_var)),
        tunable(value == kAutotune) {}

  double value;
  const std::shared_ptr<mutex> mu;
  const std::shared_ptr<condition_variable> cond_var;
  const bool tunable;
};

// Represents a parameter.
struct Parameter {
  Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
            double max)
      : name(name),
        // Sometimes non-autotune nodes (with `autotune_=false`) may contain
        // parameters (for example, inputs of a parallel interleave dataset
        // that are not in the current cycle). To avoid unrealistic values
        // (say `buffer_size=-1` or `parallelism=-1`) in the optimization
        // computation, if the state value is `kAutotune=-1` (which merely
        // indicates that the `SharedState` is tunable), we initialize the
        // parameter value to the parameter's minimum.
        value(state->value == kAutotune ? min : state->value),
        min(min),
        max(max),
        state(std::move(state)) {}

  // Human-readable name of the parameter.
  const string name;

  // Identifies the model value of the parameter. This can be different from
  // the actual value (e.g. during optimization search).
  double value;

  // Identifies the minimum value of the parameter.
  const double min;

  // Identifies the maximum value of the parameter.
  const double max;

  // Shared state of the parameter.
  std::shared_ptr<SharedState> state;
};

std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max);

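// Illustrative sketch (not part of this header's API surface): how a tunable
// parallelism parameter might be constructed. The variable names and the
// budget below are hypothetical.
//
//   auto mu = std::make_shared<mutex>();
//   auto cond_var = std::make_shared<condition_variable>();
//   // Passing `kAutotune` marks the shared state as tunable; `Parameter`
//   // then starts the search at `min` rather than at -1.
//   auto state = std::make_shared<SharedState>(kAutotune, mu, cond_var);
//   auto parallelism = MakeParameter(kParallelism, state, /*min=*/1.0,
//                                    /*max=*/port::NumSchedulableCPUs());
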
// Abstract representation of a TensorFlow input pipeline node. It collects
// information about inputs to this node, processing time spent executing the
// node logic, the number of elements produced by the node, and various other
// information (e.g. batch size or execution parallelism).
//
// Developers of tf.data transformations are not expected to interact with
// this class directly. Boilerplate code for creating the abstract
// representation of the input pipeline and collecting common information has
// been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
// respectively.
//
// In addition, `DatasetBaseIterator` provides wrappers that can be used for
// transformation-specific information collection. The `SetMetadata` wrapper
// can be used to pass arbitrary metadata to the modeling framework, while the
// `StartWork` and `StopWork` wrappers should be used to correctly account for
// processing time of multi-threaded transformations that yield the CPU; such
// transformations should invoke `StartWork()` when a transformation thread
// starts executing (e.g. when created or woken up) and `StopWork()` when a
// transformation thread stops executing (e.g. when returning or waiting).
class Node {
 public:
  // Arguments for `Node` constructor.
  struct Args {
    int64 id;
    string name;
    std::shared_ptr<Node> output;
  };

  using Factory = std::function<std::shared_ptr<Node>(Args)>;
  using NodeVector = std::vector<std::shared_ptr<Node>>;
  using NodePairList =
      std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>;
  using ModelParameters =
      std::vector<std::pair<string, std::shared_ptr<Parameter>>>;
  using NodeValues = absl::flat_hash_map<string, double>;
  using ParameterGradients =
      absl::flat_hash_map<std::pair<string, string>, double>;

  explicit Node(Args args)
      : id_(args.id),
        name_(std::move(args.name)),
        autotune_(true),
        buffered_bytes_(0),
        buffered_elements_(0),
        bytes_consumed_(0),
        bytes_produced_(0),
        num_elements_(0),
        processing_time_(0),
        record_metrics_(true),
        metrics_(name_),
        output_(args.output.get()) {}

  virtual ~Node() {
    // Clear the sub-nodes instead of relying on the implicit shared pointer
    // destructor to avoid potential stack overflow when the tree is deep.
    std::deque<std::shared_ptr<Node>> queue;
    {
      mutex_lock l(mu_);
      while (inputs_.size() > 0) {
        queue.push_back(inputs_.front());
        inputs_.pop_front();
      }
    }
    while (!queue.empty()) {
      auto node = queue.back();
      queue.pop_back();
      {
        mutex_lock l(node->mu_);
        while (node->inputs_.size() > 0) {
          queue.push_back(node->inputs_.front());
          node->inputs_.pop_front();
        }
      }
    }

    FlushMetrics();
  }

  // Adds an input.
  void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.push_back(node);
  }

  // Increments the aggregate processing time by the given delta.
  void add_processing_time(int64_t delta) TF_LOCKS_EXCLUDED(mu_) {
    processing_time_ += delta;
  }

  // Returns an indication whether autotuning is enabled for this node.
  bool autotune() const TF_LOCKS_EXCLUDED(mu_) {
    return autotune_;
  }

  // Returns the number of bytes stored in this node's buffer.
  int64 buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) {
    return buffered_bytes_;
  }

  // Returns the number of elements stored in this node's buffer.
  int64 buffered_elements() const TF_LOCKS_EXCLUDED(mu_) {
    return buffered_elements_;
  }

  // Returns the number of bytes consumed by the node.
  int64 bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) {
    return bytes_consumed_;
  }

  // Returns the number of bytes produced by the node.
  int64 bytes_produced() const TF_LOCKS_EXCLUDED(mu_) {
    return bytes_produced_;
  }

  // Indicates whether the node has tunable parameters.
  bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    for (const auto& pair : parameters_) {
      if (pair.second->state->tunable) return true;
    }
    return false;
  }

  // Returns the unique node ID.
  int64 id() const TF_LOCKS_EXCLUDED(mu_) { return id_; }

  // Returns the node inputs.
  std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return inputs_;
  }

  // Returns a longer node name that is guaranteed to be unique.
  string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); }

  // Returns the node name.
  const string& name() const { return name_; }

  // Returns the number of elements produced by the node.
  int64 num_elements() const TF_LOCKS_EXCLUDED(mu_) {
    return num_elements_;
  }

  // Returns the node output.
  Node* output() const { return output_; }

  // Returns the parameter value.
  double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return parameters_.at(name)->state->value;
  }

  // Returns the aggregate processing time.
  int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
    return processing_time_;
  }

  // Records that the node consumed the given number of bytes.
  void record_bytes_consumed(int64_t num_bytes) {
    bytes_consumed_ += num_bytes;
  }

  // Records that the node produced the given number of bytes.
  void record_bytes_produced(int64_t num_bytes) {
    bytes_produced_ += num_bytes;
  }

  // Records the change in this node's buffer.
  void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) {
    buffered_bytes_ += bytes_delta;
    buffered_elements_ += elements_delta;
  }

  // Records that the node produced an element.
  void record_element() TF_LOCKS_EXCLUDED(mu_) {
    num_elements_++;
  }

  // Records that a node thread has started executing.
  void record_start(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
    DCHECK_EQ(work_start_, 0);
    work_start_ = time_nanos;
  }

  // Records that a node thread has stopped executing.
  void record_stop(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
    // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
    if (work_start_ != 0) {
      processing_time_ += time_nanos - work_start_;
      work_start_ = 0;
    } else {
      VLOG(1) << "Encountered a stop event without a matching start event.";
    }
  }

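  // Illustrative usage (a hypothetical worker thread; see also the
  // `StartWork` and `StopWork` wrappers in `DatasetBaseIterator`): a thread
  // brackets the time it actually spends working and yields the node while
  // blocked, e.g.:
  //
  //   node->record_start(EnvTime::NowNanos());
  //   ProduceElement();                        // counted as processing time
  //   node->record_stop(EnvTime::NowNanos());
  //   cond_var->wait(l);                       // waiting is not counted
  //   node->record_start(EnvTime::NowNanos());
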
  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `record_start` and a `record_stop`.
  bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; }

  // Removes an input.
  void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.remove(input);
  }

  // Sets the value that determines whether autotuning is enabled for this
  // node.
  void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) {
    autotune_.store(autotune);
  }

  // Given the average time between events when the elements in the buffer are
  // produced (`producer_time`), the average time between events when elements
  // in the buffer are consumed (`consumer_time`) and the buffer size, the
  // method computes the expected time a consumer event will have to wait.
  //
  // The wait time is approximated as the product of the probability the buffer
  // will be empty and the time it takes to produce an element into the buffer.
  //
  // The formula used for computing the probability is derived by modeling the
  // problem as an M/M/1/K queue
  // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue).
  //
  // Collects derivatives of `ComputeWaitTime` w.r.t `producer_time`,
  // `consumer_time` and `buffer_size` if the corresponding pointers are not
  // `nullptr`.
  static double ComputeWaitTime(const double& producer_time,
                                const double& consumer_time,
                                const double& buffer_size,
                                double* producer_time_derivative,
                                double* consumer_time_derivative,
                                double* buffer_size_derivative);

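  // Back-of-the-envelope example (a sketch, assuming the M/M/1/K model
  // above): when `producer_time == consumer_time`, the stationary
  // distribution of the queue is uniform over the `buffer_size + 1` occupancy
  // states, so the buffer is empty with probability `1 / (buffer_size + 1)`.
  // With `producer_time = consumer_time = 100` and `buffer_size = 4`, the
  // expected wait is roughly `100 * 1/5 = 20`:
  //
  //   double wait = Node::ComputeWaitTime(
  //       /*producer_time=*/100.0, /*consumer_time=*/100.0,
  //       /*buffer_size=*/4.0, /*producer_time_derivative=*/nullptr,
  //       /*consumer_time_derivative=*/nullptr,
  //       /*buffer_size_derivative=*/nullptr);  // ~20.0
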
  // Collects tunable parameters in the subtree rooted in this node.
  ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);

  // Returns a human-readable representation of this node.
  string DebugString() const TF_LOCKS_EXCLUDED(mu_);

  // Flushes the metrics recorded by this node.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element output time for this node and, if `gradients` is
  // not `nullptr`, collects the output time gradient w.r.t. tunable
  // parameters of the subtree rooted in this node.
  double OutputTime(NodeValues* input_times,
                    ParameterGradients* gradients) const TF_LOCKS_EXCLUDED(mu_);

  // Returns a copy of this node, making a deep copy of its inputs and a
  // shallow copy of its tunable parameters.
  //
  // The purpose of this method is to allow the model optimization logic to
  // operate over immutable state while allowing concurrent model updates.
  std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element processing time spent in this node.
  double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the total number of bytes buffered in all nodes in the subtree for
  // which autotuning is enabled.
  double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Collects the total buffer limit of all nodes in the subtree for which
  // autotuning is enabled. This number represents the amount of memory that
  // would be used by the subtree nodes if all of their buffers were full.
  double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element CPU time spent in the subtree rooted in this node.
  // If `processing_times` is not `nullptr`, collects the per-element CPU time
  // spent in each node of the subtree.
  double TotalProcessingTime(NodeValues* processing_times)
      TF_LOCKS_EXCLUDED(mu_);

  // Produces a proto for this node. Does not produce a proto for input nodes.
  virtual Status ToProto(ModelProto::Node* node_proto) const;

  // Restores a node from the proto. Does not restore input nodes.
  static Status FromProto(ModelProto::Node node_proto,
                          std::shared_ptr<Node> output,
                          std::shared_ptr<Node>* node);

 protected:
  // Used for (incrementally) recording metrics. The class is thread-safe.
  class Metrics {
   public:
    explicit Metrics(const string& name)
        : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
          bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
          num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
          recorded_bytes_consumed_(0),
          recorded_bytes_produced_(0),
          recorded_num_elements_(0) {}

    // Expects the total number of bytes consumed and records the delta since
    // the last invocation.
    void record_bytes_consumed(int64_t total_bytes) {
      int64_t delta =
          total_bytes - recorded_bytes_consumed_.exchange(total_bytes);
      bytes_consumed_counter_->IncrementBy(delta);
    }

    // Expects the total number of bytes produced and records the delta since
    // the last invocation.
    void record_bytes_produced(int64_t total_bytes) {
      int64_t delta =
          total_bytes - recorded_bytes_produced_.exchange(total_bytes);
      bytes_produced_counter_->IncrementBy(delta);
    }

    // Expects the total number of elements produced and records the delta
    // since the last invocation.
    void record_num_elements(int64_t total_elements) {
      int64_t delta =
          total_elements - recorded_num_elements_.exchange(total_elements);
      num_elements_counter_->IncrementBy(delta);
    }

   private:
    monitoring::CounterCell* const bytes_consumed_counter_;
    monitoring::CounterCell* const bytes_produced_counter_;
    monitoring::CounterCell* const num_elements_counter_;
    std::atomic<int64> recorded_bytes_consumed_;
    std::atomic<int64> recorded_bytes_produced_;
    std::atomic<int64> recorded_num_elements_;
  };

  // Returns the number of inputs.
  int64 num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    int64_t num_inputs = 0;
    for (auto& input : inputs_) {
      // Inputs for which autotuning is disabled are excluded.
      if (input->autotune()) {
        ++num_inputs;
      }
    }
    return num_inputs;
  }

  // Creates a clone of this node.
  virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the average size of an element buffered in this node.
  double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of per-element output time for the tunable inputs of this
  // node.
  double OutputTimeForInputs(const NodeValues& output_times) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of output time gradient w.r.t. input time for the tunable
  // inputs of this node.
  double OutputTimeGradientsForInputs(const NodeValues& output_time_gradients)
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the input time for this node and stores it in `input_times`.
  virtual void InputTimeLocked(NodeValues* input_times) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Computes the per-element output time for this node and stores it in
  // `output_times`. If `gradients` is not `nullptr`, computes the output time
  // gradient w.r.t. tunable parameters of the subtree rooted in this node and
  // stores it in `gradients`; it also computes the output time gradient
  // w.r.t. input time and stores it in `output_time_gradients`.
  virtual void OutputTimeLocked(const NodeValues& input_times,
                                ParameterGradients* gradients,
                                NodeValues* output_times,
                                NodeValues* output_time_gradients) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the sum of per-element processing time for the inputs of this node
  // by adding values for input nodes in `total_processing_times`. Processing
  // time for a given input is a weighted combination of a statistic based on
  // the history of input processing times and the actual time. This is done to
  // improve the accuracy of processing time estimation for newly created
  // inputs.
  //
  // A uniform distribution of per-element processing times across different
  // inputs is assumed.
  double TotalProcessingTimeForInputs(const NodeValues& total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the per-element processing time spent in this node.
  double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the per-element CPU time spent in the subtree rooted in this node
  // and stores it in `total_processing_times`. If `processing_times` is not
  // `nullptr`, collects the per-element CPU time spent in each node of the
  // subtree.
  virtual void TotalProcessingTimeLocked(NodeValues* processing_times,
                                         NodeValues* total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns a vector of nodes of the subtree rooted in this node. The nodes
  // are in either breadth-first search or reverse breadth-first search order,
  // depending on the `order` argument. The nodes are collected based on the
  // results of the `collect_node` predicate: if the predicate returns `false`
  // for a given node, then the subtree rooted in that node is excluded. The
  // root node itself is not collected.
  NodeVector CollectNodes(TraversalOrder order,
                          bool collect_node(const std::shared_ptr<Node>)) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Collects tunable parameters in the subtree rooted in this node, assuming
  // the mutex is held.
  ModelParameters CollectTunableParametersLocked() const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Collects tunable parameters of the nodes that have recorded elements.
  void CollectTunableParametersHelper(ModelParameters* parameters) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Builds up the debug string for the node and stores it in the debug
  // strings map.
  void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Copies the node and adds the (input, copy) pairs to the `NodePairList`.
  std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output,
                                       NodePairList* node_pairs) const;

  // Computes the total buffered bytes for the node and stores them in the
  // total bytes map.
  void TotalBufferedBytesHelper(NodeValues* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the total maximum buffered bytes for the node and stores them in
  // the total bytes map.
  void TotalMaximumBufferedBytesHelper(NodeValues* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes and returns the maximum buffered bytes on the node itself. By
  // default, non-tunable nodes are assumed not to buffer any bytes, so tunable
  // node subclasses are expected to override this method to ensure that the
  // optimization algorithm respects the memory budget.
  virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Restores a node from the proto. Note that this is not done recursively,
  // i.e. input nodes are not restored.
  static Status FromProtoHelper(ModelProto::Node node_proto,
                                std::shared_ptr<Node> node);

  // Stores the time passed to the last call to `Node::record_start()` on the
  // current thread.
  //
  // NOTE: This thread-local variable is shared between all instances of `Node`
  // on which the same thread calls `record_start()` or `record_stop()`. It
  // relies on the invariant that at most one `Node` can be "active" on a
  // particular thread at any time. Therefore if `n->record_start()` is called
  // on thread `t`, then `n->record_stop()` must be called before another call
  // to `Node::record_start()` (for any node).
  static thread_local int64_t work_start_;  // Will be initialized to zero.

  mutable mutex mu_;
  const int64 id_;
  const string name_;

  // Indicates whether the subtree rooted in this node should be included in
  // autotuning. In particular, if this is `false`, then the subtree is
  // excluded from computation of output time and processing time.
  std::atomic<bool> autotune_;
  std::atomic<int64> buffered_bytes_;
  std::atomic<int64> buffered_elements_;
  std::atomic<int64> bytes_consumed_;
  std::atomic<int64> bytes_produced_;
  std::atomic<int64> num_elements_;
  std::atomic<int64> processing_time_;
  std::atomic<bool> record_metrics_;
  Metrics metrics_;
  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
      TF_GUARDED_BY(mu_);

  // Statistic of the input processing time history.
  double input_processing_time_sum_ = 0.0L;
  int64 input_processing_time_count_ = 0;

  // Inputs of this node. These can represent an iterator created from the
  // input dataset but also other input iterators (e.g. created by the
  // user-defined functions of `flat_map` or `interleave`).
  std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_);

  // The reference to the output node is not owned so that deletion of a
  // node results in recursive deletion of the subtree rooted in the node.
  Node* const output_;
};

// InterleaveMany is used to model datasets whose inputs are used to create
// datasets whose elements are then interleaved.
std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args);

// AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany
// nodes.
std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);

// KnownRatio nodes model datasets that synchronously consume a known number
// of input elements per output element.
std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);

// AsyncKnownRatio nodes are the asynchronous version of KnownRatio nodes.
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio, double memory_ratio,
    std::vector<std::shared_ptr<Parameter>> parameters);

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters);

// Source nodes represent data sources.
std::shared_ptr<Node> MakeSourceNode(Node::Args args);

// UnknownRatio nodes represent datasets that synchronously consume an
// unknown number of input elements per output element.
//
// Unlike KnownRatio nodes, which expect the ratio between inputs and outputs
// to be specified as a parameter, UnknownRatio nodes estimate the ratio
// empirically.
std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args);

// Unknown nodes represent datasets for which we do not have a model. They
// act as a pass-through between inputs and output.
std::shared_ptr<Node> MakeUnknownNode(Node::Args args);

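// Illustrative sketch (the batching dataset and its `batch_size_` member are
// hypothetical): a dataset picks the factory that matches its execution model
// by overriding `DatasetBase::CreateNode`. For example, a dataset that
// synchronously consumes `batch_size_` input elements per output element
// might use:
//
//   std::shared_ptr<model::Node> CreateNode(
//       IteratorContext* ctx, model::Node::Args args) const override {
//     return model::MakeKnownRatioNode(std::move(args), batch_size_);
//   }
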
// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
// create a performance model, which is in turn used to identify optimal values
// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boilerplate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
class Model {
 public:
  using OptimizationParams = ModelProto::OptimizationParams;
  using ModelParameters = Node::ModelParameters;
  using NodeValues = Node::NodeValues;
  using ParameterGradients = Node::ParameterGradients;

  Model();
  ~Model();

  // Indicates whether to collect resource usage.
  bool collect_resource_usage() const { return collect_resource_usage_; }

  // Returns a pointer to the model's output node.
  const std::shared_ptr<Node> output() {
    mutex_lock l(mu_);
    return output_;
  }

  // Adds a node with the given name and given parent.
  void AddNode(Node::Factory factory, const string& name,
               std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
      TF_LOCKS_EXCLUDED(mu_);

  // Returns a human-readable string representation of the model. This method
  // can be invoked automatically by monitoring gauges; to avoid frequent
  // recomputation, the implementation caches the result.
  std::string DebugString();

  // Uses the given algorithm and resource budgets to periodically perform the
  // autotuning optimization.
  //
  // To terminate the execution of the optimization loop, the caller needs to
  // invoke `cancellation_manager->StartCancel()`.
  Status OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
                      int64_t ram_budget,
                      CancellationManager* cancellation_manager);

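  // Illustrative sketch (a hypothetical background thread; assumes the
  // `HILL_CLIMB` value of the `AutotuneAlgorithm` enum from model.proto and
  // hypothetical budget values):
  //
  //   CancellationManager cancellation_manager;
  //   std::thread optimize_thread([&] {
  //     Status s = model->OptimizeLoop(
  //         AutotuneAlgorithm::HILL_CLIMB,
  //         /*cpu_budget=*/port::NumSchedulableCPUs(),
  //         /*ram_budget=*/2LL << 30, &cancellation_manager);
  //   });
  //   ...
  //   cancellation_manager.StartCancel();  // terminates the loop
  //   optimize_thread.join();
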
  // Uses the given algorithm and resource budgets to perform the autotuning
  // optimization.
  void Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
                int64_t ram_budget, double model_input_time,
                CancellationManager* cancellation_manager);

  // Collects the output time and, if `gradients` is not `nullptr`, the output
  // time gradient w.r.t. tunable parameters of the subtree rooted in the given
  // node.
  double OutputTime(std::shared_ptr<Node> node, double model_input_time,
                    ParameterGradients* gradients);

  // Removes the given node.
  void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);

  // Produces a proto for this model.
  Status ToProto(ModelProto* model_proto);

  // Restores a model from the proto.
  static Status FromProto(ModelProto model_proto,
                          std::unique_ptr<Model>* model);

  // Saves this model with a given snapshot and its optimization parameters to
  // a file. Note that the file directory must already exist.
  Status Save(const string& fname, std::shared_ptr<Node> snapshot,
              const OptimizationParams& optimization_params);

  // Loads a model and its optimization parameters from a file with the given
  // name.
  static Status Load(const string& fname, std::unique_ptr<Model>* model,
                     OptimizationParams* optimization_params);

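  // Illustrative round trip (a sketch; the path and the default-constructed
  // optimization parameters are hypothetical):
  //
  //   ModelProto::OptimizationParams params;
  //   TF_RETURN_IF_ERROR(model->Save(
  //       "/tmp/model.pb", model->output()->Snapshot(), params));
  //   std::unique_ptr<Model> restored;
  //   TF_RETURN_IF_ERROR(Model::Load("/tmp/model.pb", &restored, &params));
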
 private:
  static constexpr int64_t kOptimizationPeriodMinMs = 10;
  static constexpr int64_t kOptimizationPeriodMaxMs =
      60 * EnvTime::kSecondsToMillis;

  // Collects tunable parameters in the tree rooted in the given node,
  // returning a vector that contains pairs of node names and tunable
  // parameters.
  ModelParameters CollectTunableParameters(std::shared_ptr<Node> node);

  // Flushes metrics recorded by the model.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);

  // This optimization algorithm starts by setting all tunable parallelism
  // parameters to the minimum value. It then repeatedly identifies the
  // parameter whose increase in parallelism decreases the output time the
  // most. This process is repeated until all parameters reach their maximum
  // values or the projected output time is less than or equal to the
  // processing time needed to produce an element divided by the CPU budget.
  void OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                         const OptimizationParams& optimization_params,
                         CancellationManager* cancellation_manager);

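  // Pseudocode sketch of the hill-climb loop described above (the names are
  // illustrative, not the actual implementation):
  //
  //   for (auto& pair : parameters) pair.second->value = pair.second->min;
  //   while (!AllParametersAtMax(parameters) &&
  //          OutputTime() > ProcessingTime() / cpu_budget) {
  //     Parameter* best = ParameterWithLargestOutputTimeDropWhenIncremented();
  //     best->value += 1;
  //   }
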
  // This optimization algorithm starts by setting all tunable parallelism
  // parameters to the minimum value. It then improves the current parameters
  // by making a step in the direction opposite to the gradient of `OutputTime`
  // and projecting the resulting values onto the feasible intervals. The
  // improvement step is repeated until either the output time improvement is
  // smaller than a threshold value or the output time is less than the
  // processing time needed to produce an element divided by the CPU budget.
  void OptimizeGradientDescent(std::shared_ptr<Node> snapshot,
                               const OptimizationParams& optimization_params,
                               CancellationManager* cancellation_manager);

  // Determines whether we should stop the gradient descent optimization
  // iterations based on the number of increasable parameters, the CPU budget,
  // the RAM budget, and current resource usage.
  bool ShouldStop(int64_t cpu_budget, int64_t ram_budget,
                  const ModelParameters& parameters,
                  const ModelParameters& parallelism_parameters,
                  const ModelParameters& buffer_size_parameters,
                  std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);

  // Collects the processing time for the given node.
  double TotalProcessingTime(std::shared_ptr<Node> node);

  // Collects the total number of bytes buffered in all nodes in the subtree
  // rooted in the given node for which autotuning is enabled.
  double TotalBufferedBytes(std::shared_ptr<Node> node);

  // Collects the total buffer limit of all nodes in the subtree rooted in the
  // given node for which autotuning is enabled. This number represents the
  // amount of memory that would be used by the subtree nodes if all of their
  // buffers were full.
  double TotalMaximumBufferedBytes(std::shared_ptr<Node> node);

  // Used for coordination between different input pipeline threads. Exclusive
  // access is required only when adding or removing nodes. Concurrent access
  // to existing nodes is protected by a node mutex.
  mutex mu_;
  // Used for coordinating the optimization loop and model modifications.
  condition_variable optimize_cond_var_;
  int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
  std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_) = nullptr;

  // Indicates whether the modeling framework should collect resource usage
  // (e.g. CPU, memory). The logic for collecting this information assumes that
  // the collection is not repeatedly disabled and enabled. As a consequence,
  // the implementation starts collecting resource usage when it encounters a
  // tunable parameter (because the information is used for tuning the value of
  // the parameter) and never stops.
  std::atomic<bool> collect_resource_usage_;

  // Determines the time the optimization loop should wait between
  // running optimizations.
  int64 optimization_period_ms_ TF_GUARDED_BY(mu_);

  // Gauge cell that can be used to collect the state of the model.
  monitoring::GaugeCell<std::function<std::string()>>* model_gauge_cell_ =
      nullptr;
  // Time used for rate limiting the recomputation of the human-readable
  // string representation of the model.
  absl::Time cache_until_ = absl::InfinitePast();
  // Cached result of the `DebugString()` invocation used to implement rate
  // limiting of the computation.
  std::string cached_debug_string_ = "";
};

}  // namespace model
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_