/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_

#include <atomic>
#include <cstdlib>
#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <string>
// TODO(b/114492873): Move this include into core/platform.
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"

namespace tensorflow {
namespace data {
namespace model {

// A constant that can be used to enable auto-tuning.
constexpr int64 kAutotune = -1;
constexpr char kParallelism[] = "parallelism";
constexpr char kBufferSize[] = "buffer_size";

// A key used to identify the input time of the model.
constexpr char kModelInputTimeKey[] = "model_input_time";

enum class TraversalOrder {
  BFS = 0,
  REVERSE_BFS = 1,
};

// Represents thread-safe state that can be shared between an input pipeline and
// the performance model.
struct SharedState {
 public:
  SharedState(int64 value, std::shared_ptr<mutex> mu,
              std::shared_ptr<condition_variable> cond_var)
      : value(value),
        mu(std::move(mu)),
        cond_var(std::move(cond_var)),
        tunable(value == kAutotune) {}

  double value;
  const std::shared_ptr<mutex> mu;
  const std::shared_ptr<condition_variable> cond_var;
  const bool tunable;
};

// Represents a parameter.
struct Parameter {
  Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
            double max)
      : name(name),
        // Sometimes non-autotune nodes (with `autotune_=false`) may contain
        // parameters (for example, inputs of a parallel interleave dataset
        // that are not in the current cycle). To avoid unrealistic situations
        // (say `buffer_size=-1` or `parallelism=-1`) in the optimization
        // computation, if the state value is `kAutotune=-1` (which merely
        // indicates that the `SharedState` is tunable), we initialize the
        // parameter value to its minimum (`min`).
        value(state->value == kAutotune ? min : state->value),
        min(min),
        max(max),
        state(std::move(state)) {}

  // Human-readable name of the parameter.
  const string name;

  // Identifies the model value of the parameter. This can be different from
  // the actual value (e.g. during optimization search).
  double value;

  // Identifies the minimum value of the parameter.
  const double min;

  // Identifies the maximum value of the parameter.
  const double max;

  // Shared state of the parameter.
  std::shared_ptr<SharedState> state;
};

std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max);
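
// Illustrative usage (a sketch; the local variable names are hypothetical):
// a transformation whose "parallelism" knob should be autotuned would create
// a tunable `SharedState` (by passing `kAutotune` as the initial value) and
// wrap it in a `Parameter`:
//
//   auto mu = std::make_shared<mutex>();
//   auto cond_var = std::make_shared<condition_variable>();
//   auto state = std::make_shared<SharedState>(kAutotune, mu, cond_var);
//   auto parallelism = MakeParameter(kParallelism, state, /*min=*/1,
//                                    /*max=*/port::NumSchedulableCPUs());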

// Abstract representation of a TensorFlow input pipeline node. It collects
// information about inputs to this node, processing time spent executing the
// node logic, number of elements produced by the node, and various other
// information (e.g. batch size or execution parallelism).
//
// Developers of tf.data transformations are not expected to interact with
// this class directly. Boilerplate code for creating the abstract
// representation of the input pipeline and collecting common information has
// been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
// respectively.
//
// In addition, `DatasetBaseIterator` provides wrappers that can be used for
// transformation-specific information collection. The `SetMetadata` wrapper
// can be used to pass arbitrary metadata to the modeling framework, while the
// `StartWork` and `StopWork` wrappers should be used to correctly account for
// processing time of multi-threaded transformations that yield the CPU; such
// transformations should invoke `StartWork()` when a transformation thread
// starts executing (e.g. when created or woken up) and `StopWork()` when a
// transformation thread stops executing (e.g. when returning or waiting).
class Node {
 public:
  // Arguments for `Node` constructor.
  struct Args {
    int64 id;
    string name;
    std::shared_ptr<Node> output;
  };

  using Factory = std::function<std::shared_ptr<Node>(Args)>;
  using NodeVector = std::vector<std::shared_ptr<Node>>;
  using NodePairList =
      std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>;

  explicit Node(Args args)
      : id_(args.id),
        name_(std::move(args.name)),
        autotune_(true),
        buffered_bytes_(0),
        buffered_elements_(0),
        bytes_consumed_(0),
        bytes_produced_(0),
        num_elements_(0),
        processing_time_(0),
        record_metrics_(true),
        metrics_(name_),
        output_(args.output.get()) {}

  virtual ~Node() {
    // Clear the sub-nodes instead of relying on implicit shared pointer
    // destructor to avoid potential stack overflow when the tree is deep.
    std::deque<std::shared_ptr<Node>> queue;
    {
      mutex_lock l(mu_);
      while (inputs_.size() > 0) {
        queue.push_back(inputs_.front());
        inputs_.pop_front();
      }
    }
    while (!queue.empty()) {
      auto node = queue.back();
      queue.pop_back();
      {
        mutex_lock l(node->mu_);
        while (node->inputs_.size() > 0) {
          queue.push_back(node->inputs_.front());
          node->inputs_.pop_front();
        }
      }
    }

    FlushMetrics();
  }

  // Adds an input.
  void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.push_back(node);
  }

  // Increments the aggregate processing time by the given delta.
  void add_processing_time(int64 delta) TF_LOCKS_EXCLUDED(mu_) {
    processing_time_ += delta;
  }

  // Returns an indication whether autotuning is enabled for this node.
  bool autotune() const TF_LOCKS_EXCLUDED(mu_) {
    return autotune_;
  }

  // Returns the number of bytes stored in this node's buffer.
  int64 buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) {
    return buffered_bytes_;
  }

  // Returns the number of elements stored in this node's buffer.
  int64 buffered_elements() const TF_LOCKS_EXCLUDED(mu_) {
    return buffered_elements_;
  }

  // Returns the number of bytes consumed by the node.
  int64 bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) {
    return bytes_consumed_;
  }

  // Returns the number of bytes produced by the node.
  int64 bytes_produced() const TF_LOCKS_EXCLUDED(mu_) {
    return bytes_produced_;
  }

  // Indicates whether the node has tunable parameters.
  bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    for (const auto& pair : parameters_) {
      if (pair.second->state->tunable) return true;
    }
    return false;
  }

  // Returns the unique node ID.
  int64 id() const TF_LOCKS_EXCLUDED(mu_) { return id_; }

  // Returns the node inputs.
  std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return inputs_;
  }

  // Returns a longer node name that is guaranteed to be unique.
  string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); }

  // Returns the node name.
  const string& name() const { return name_; }

  // Returns the number of elements produced by the node.
  int64 num_elements() const TF_LOCKS_EXCLUDED(mu_) {
    return num_elements_;
  }

  // Returns the node output.
  Node* output() const { return output_; }

  // Returns the parameter value.
  double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return parameters_.at(name)->state->value;
  }

  // Returns the aggregate processing time.
  int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
    return processing_time_;
  }

  // Records that the node consumed the given number of bytes.
  void record_bytes_consumed(int64 num_bytes) { bytes_consumed_ += num_bytes; }

  // Records that the node produced the given number of bytes.
  void record_bytes_produced(int64 num_bytes) { bytes_produced_ += num_bytes; }

  // Records the change in this node's buffer.
  void record_buffer_event(int64 bytes_delta, int64 elements_delta) {
    buffered_bytes_ += bytes_delta;
    buffered_elements_ += elements_delta;
  }

  // Records that the node produced an element.
  void record_element() TF_LOCKS_EXCLUDED(mu_) {
    num_elements_++;
  }

  // Records that a node thread has started executing.
  void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
    DCHECK_EQ(work_start_, 0);
    work_start_ = time_nanos;
  }

  // Records that a node thread has stopped executing.
  void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
    // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
    if (work_start_ != 0) {
      processing_time_ += time_nanos - work_start_;
      work_start_ = 0;
    } else {
      VLOG(1) << "Encountered a stop event without a matching start event.";
    }
  }

  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `record_start` and a `record_stop`.
  bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; }
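
  // Illustrative usage (a sketch; the surrounding thread logic is assumed, not
  // prescribed by this header): a background thread of an asynchronous
  // transformation brackets the work it performs on behalf of this node so
  // that only time spent actively executing is attributed to it.
  //
  //   node->record_start(EnvTime::NowNanos());
  //   // ... produce an element, possibly calling record_element() ...
  //   node->record_stop(EnvTime::NowNanos());
  //   // While blocked (e.g. waiting on a condition variable), no processing
  //   // time is recorded for the node.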

  // Removes an input.
  void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.remove(input);
  }

  // Sets the value that determines whether autotuning is enabled for this
  // node.
  void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) {
    autotune_.store(autotune);
  }

  // Given the average time between events when the elements in the buffer are
  // produced (`producer_time`), the average time between events when elements
  // in the buffer are consumed (`consumer_time`), and the buffer size, this
  // method computes the expected time a consumer event will have to wait.
  //
  // The wait time is approximated as the product of the probability the buffer
  // will be empty and the time it takes to produce an element into the buffer.
  //
  // The formula used for computing the probability is derived by modeling the
  // problem as an M/M/1/K queue
  // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue).
  //
  // Collects derivatives of `ComputeWaitTime` w.r.t. `producer_time`,
  // `consumer_time` and `buffer_size` if the corresponding pointers are not
  // `nullptr`.
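  //
  // As a sketch (using the standard M/M/1/K steady-state result; the
  // implementation may treat edge cases such as zero times separately), with
  // r = consumer_time / producer_time and buffer size K:
  //
  //   P(buffer empty) = (1 - r) / (1 - r^(K+1))   if r != 1
  //                   = 1 / (K + 1)               if r == 1
  //   wait_time ~ P(buffer empty) * producer_time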
  static double ComputeWaitTime(const double& producer_time,
                                const double& consumer_time,
                                const double& buffer_size,
                                double* producer_time_derivative,
                                double* consumer_time_derivative,
                                double* buffer_size_derivative);

  // Collects tunable parameters in the subtree rooted in this node.
  void CollectTunableParameters(
      absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const
      TF_LOCKS_EXCLUDED(mu_);

  // Returns a human-readable representation of this node.
  string DebugString() const TF_LOCKS_EXCLUDED(mu_);

  // Flushes the metrics recorded by this node.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element output time for this node and if `gradients` is not
  // `nullptr`, collects the output time gradient w.r.t. tunable parameters of
  // the subtree rooted in this node.
  double OutputTime(absl::flat_hash_map<string, double>* input_times,
                    absl::flat_hash_map<string, double>* gradients) const
      TF_LOCKS_EXCLUDED(mu_);

  // Returns a copy of this node, making a deep copy of its inputs and a
  // shallow copy of its tunable parameters.
  //
  // The purpose for this method is to allow the model optimization logic to
  // operate over immutable state while allowing concurrent model updates.
  std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element processing time spent in this node.
  double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the total number of bytes buffered in all nodes in the subtree for
  // which autotuning is enabled.
  double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Collects the total buffer limit of all nodes in the subtree for which
  // autotuning is enabled. This number represents the amount of memory that
  // would be used by the subtree nodes if all of their buffers were full.
  double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element CPU time spent in the subtree rooted in this node.
  // If `processing_times` is not `nullptr`, collects the per-element CPU time
  // spent in each node of the subtree.
  double TotalProcessingTime(
      absl::flat_hash_map<string, double>* processing_times)
      TF_LOCKS_EXCLUDED(mu_);

  // Recursively produces a proto for this node and its subtree.
  virtual Status ToProto(ModelProto::Node* node_proto) const;

  // Recursively restores a node and its subtree from the proto.
  static Status FromProto(ModelProto::Node node_proto,
                          std::shared_ptr<Node> output,
                          std::shared_ptr<Node>* node);

 protected:
  // Used for (incrementally) recording metrics. The class is thread-safe.
  class Metrics {
   public:
    explicit Metrics(const string& name)
        : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
          bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
          num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
          recorded_bytes_consumed_(0),
          recorded_bytes_produced_(0),
          recorded_num_elements_(0) {}

    // Expects the total number of bytes consumed and records the delta since
    // last invocation.
    void record_bytes_consumed(int64 total_bytes) {
      int64 delta =
          total_bytes - recorded_bytes_consumed_.exchange(total_bytes);
      bytes_consumed_counter_->IncrementBy(delta);
    }

    // Expects the total number of bytes produced and records the delta since
    // last invocation.
    void record_bytes_produced(int64 total_bytes) {
      int64 delta =
          total_bytes - recorded_bytes_produced_.exchange(total_bytes);
      bytes_produced_counter_->IncrementBy(delta);
    }

    // Expects the total number of elements produced and records the delta since
    // last invocation.
    void record_num_elements(int64 total_elements) {
      int64 delta =
          total_elements - recorded_num_elements_.exchange(total_elements);
      num_elements_counter_->IncrementBy(delta);
    }

   private:
    monitoring::CounterCell* const bytes_consumed_counter_;
    monitoring::CounterCell* const bytes_produced_counter_;
    monitoring::CounterCell* const num_elements_counter_;
    std::atomic<int64> recorded_bytes_consumed_;
    std::atomic<int64> recorded_bytes_produced_;
    std::atomic<int64> recorded_num_elements_;
  };

  // Returns the number of inputs.
  int64 num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    int64 num_inputs = 0;
    for (auto& input : inputs_) {
      // Inputs for which autotuning is disabled are excluded.
      if (input->autotune()) {
        ++num_inputs;
      }
    }
    return num_inputs;
  }

  // Creates a clone of this node.
  virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the average size of an element buffered in this node.
  double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of per-element output time for the tunable inputs of this
  // node.
  double OutputTimeForInputs(
      const absl::flat_hash_map<string, double>& output_times) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of output time gradient w.r.t. input time for the tunable
  // inputs of this node.
  double OutputTimeGradientsForInputs(
      const absl::flat_hash_map<string, double>& output_time_gradients) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the input time for this node and stores it in `input_times`.
  virtual void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
      const TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Computes the per-element output time for this node and stores it in
  // `output_times`. If `gradients` is not `nullptr`, computes the output time
  // gradient w.r.t. tunable parameters of the subtree rooted in this node and
  // stores it in `gradients`; it also computes the output time gradient w.r.t.
  // input time and stores it in `output_time_gradients`.
  virtual void OutputTimeLocked(
      const absl::flat_hash_map<string, double>& input_times,
      absl::flat_hash_map<string, double>* gradients,
      absl::flat_hash_map<string, double>* output_times,
      absl::flat_hash_map<string, double>* output_time_gradients) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the sum of per-element processing time for the inputs of this node
  // by adding values for input nodes in `total_processing_times`. Processing
  // time for a given input is a weighted combination of a statistic based on
  // the history of input processing time and the actual time. This is done to
  // improve accuracy of processing time estimation for newly created inputs.
  //
  // Uniform distribution of per-element processing times across different
  // inputs is assumed.
  double TotalProcessingTimeForInputs(
      const absl::flat_hash_map<string, double>& total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the per-element processing time spent in this node.
  double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the per-element CPU time spent in the subtree rooted in this node
  // and stores it in `total_processing_times`. If `processing_times` is not
  // `nullptr`, collects the per-element CPU time spent in each node of the
  // subtree.
  virtual void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns a vector of nodes of the subtree rooted in this node. The nodes are
  // either in breadth-first search or reverse breadth-first search order
  // depending on the `order` argument. The nodes are collected based on the
  // results of the `collect_node` predicate: if the predicate returns `false`
  // for a given node, then the subtree rooted in that node is excluded. The
  // root node itself is not collected.
  NodeVector CollectNodes(TraversalOrder order,
                          bool collect_node(const std::shared_ptr<Node>)) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Collects tunable parameters on the nodes which have recorded elements.
  void CollectTunableParametersHelper(
      absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Builds up the debug string for the node and stores it in the debug strings
  // map.
  void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Copies the node and adds the (input, copy) pairs to the NodePairList.
  std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output,
                                       NodePairList* node_pairs) const;

  // Computes the total buffered bytes for the node and stores them in the
  // total bytes map.
  void TotalBufferedBytesHelper(
      absl::flat_hash_map<string, double>* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the total maximum buffered bytes for the node and stores them in
  // the total bytes map.
  void TotalMaximumBufferedBytesHelper(
      absl::flat_hash_map<string, double>* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes and returns the maximum buffered bytes on the node itself. By
  // default, non-tunable nodes are assumed not to buffer any bytes, so tunable
  // node subclasses are expected to override this method to ensure that the
  // optimization algorithm respects the memory budget.
  virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Restores a node from the proto. Note that this is not done recursively,
  // i.e. input nodes are not restored.
  static Status FromProtoHelper(ModelProto::Node node_proto,
                                std::shared_ptr<Node> node);

  // Stores the time passed to the last call to `Node::record_start()` on the
  // current thread.
  //
  // NOTE: This thread-local variable is shared between all instances of `Node`
  // on which the same thread calls `record_start()` or `record_stop()`. It
  // relies on the invariant that at most one `Node` can be "active" on a
  // particular thread at any time. Therefore if `n->record_start()` is called
  // on thread `t`, then `n->record_stop()` must be called before another call
  // to `Node::record_start()` (for any node).
  static thread_local int64 work_start_;  // Will be initialized to zero.

  mutable mutex mu_;
  const int64 id_;
  const string name_;

  // Indicates whether the subtree rooted in this node should be included in
  // autotuning. In particular, if this is `false`, then the subtree is excluded
  // from computation of output time and processing time.
  std::atomic<bool> autotune_;
  std::atomic<int64> buffered_bytes_;
  std::atomic<int64> buffered_elements_;
  std::atomic<int64> bytes_consumed_;
  std::atomic<int64> bytes_produced_;
  std::atomic<int64> num_elements_;
  std::atomic<int64> processing_time_;
  std::atomic<bool> record_metrics_;
  Metrics metrics_;
  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
      TF_GUARDED_BY(mu_);

  // Statistics of the input processing time history.
  double input_processing_time_sum_ = 0.0L;
  int64 input_processing_time_count_ = 0;

  // Inputs of this node. These can represent an iterator created from the
  // input dataset as well as other input iterators (e.g. ones created by the
  // user-defined functions of `flat_map` or `interleave`).
  std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_);

  // The reference to the output node is not owned so that deletion of a
  // node results in recursive deletion of the subtree rooted in the node.
  Node* const output_;
};

// InterleaveMany is used to model datasets whose inputs are used to create
// datasets whose elements are then interleaved.
std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args);

// AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany
// nodes.
std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);

// KnownRatio nodes model datasets that synchronously consume a known number of
// input elements per output element.
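//
// For example (an illustrative reading of `ratio`): a `batch(32)`
// transformation would be modeled with `ratio = 32`, since it consumes 32
// input elements per output element, whereas a one-to-one transformation such
// as `map` would use `ratio = 1`.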
std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);

// AsyncKnownRatio nodes are the asynchronous version of KnownRatio nodes.
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio, double memory_ratio,
    std::vector<std::shared_ptr<Parameter>> parameters);

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters);

// Source nodes represent data sources.
std::shared_ptr<Node> MakeSourceNode(Node::Args args);

// UnknownRatio nodes represent datasets that synchronously consume an
// unknown number of input elements per output element.
//
// Unlike KnownRatio nodes, which expect the ratio between inputs and outputs
// to be specified as a parameter, UnknownRatio nodes estimate the ratio
// empirically.
std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args);

// Unknown nodes represent datasets for which we do not have a model. They act
// as a pass-through between inputs and output.
std::shared_ptr<Node> MakeUnknownNode(Node::Args args);

// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
// create a performance model, which is in turn used to identify optimal values
// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boilerplate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
class Model {
 public:
  using OptimizationParams = ModelProto::OptimizationParams;

  // Creates a new model.
  Model()
      : collect_resource_usage_(false),
        optimization_period_ms_(kOptimizationPeriodMinMs) {
    const char* save_dir = std::getenv("TF_DATA_AUTOTUNE_DEBUG_DIR");
    if (save_dir) {
      save_dir_ = string(save_dir);
    }
  }

  ~Model() {
    if (!save_dir_.empty()) {
      save_thread_cancelled_ = true;
      save_cond_var_.notify_all();
    }
  }

  // Indicates whether to collect resource usage.
  bool collect_resource_usage() const { return collect_resource_usage_; }

  // Returns a pointer to the model's output node.
  const std::shared_ptr<Node> output() {
    mutex_lock l(mu_);
    return output_;
  }

  // Adds a node with the given name and given parent.
  void AddNode(Node::Factory factory, const string& name,
               std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
      TF_LOCKS_EXCLUDED(mu_);

  // Uses the given algorithm and resource budgets to periodically perform the
  // autotuning optimization.
  //
  // To terminate the execution of the optimization loop, the caller needs to
  // invoke `cancellation_mgr->StartCancel()`.
  Status OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
                      int64 ram_budget, CancellationManager* cancellation_mgr);
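
  // Illustrative usage (a sketch; the algorithm value, budgets, and threading
  // shown here are assumptions for the example rather than requirements):
  //
  //   CancellationManager cancellation_mgr;
  //   std::thread optimize_thread([&] {
  //     model->OptimizeLoop(AutotuneAlgorithm::HILL_CLIMB,
  //                         /*cpu_budget=*/port::NumSchedulableCPUs(),
  //                         /*ram_budget=*/4LL << 30, &cancellation_mgr)
  //         .IgnoreError();
  //   });
  //   ...
  //   cancellation_mgr.StartCancel();  // Terminates the optimization loop.
  //   optimize_thread.join();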

  // Uses the given algorithm and resource budgets to perform the autotuning
  // optimization.
  void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget,
                double model_input_time);

  // Removes the given node.
  void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);

  // Produces a proto for this model.
  Status ToProto(ModelProto* model_proto);

  // Restores a model from the proto.
  static Status FromProto(ModelProto model_proto,
                          std::unique_ptr<Model>* model);

  // Saves this model with a given snapshot and its optimization parameters to a
  // file. Note that the file directory must already exist.
  Status Save(const string& fname, std::shared_ptr<Node> snapshot,
              const OptimizationParams& optimization_params);

  // Loads a model and its optimization parameters from a file with the given
  // name.
  static Status Load(const string& fname, std::unique_ptr<Model>* model,
                     OptimizationParams* optimization_params);

 private:
  static constexpr int64 kOptimizationPeriodMinMs = 10;
  static constexpr int64 kOptimizationPeriodMaxMs =
      60 * EnvTime::kSecondsToMillis;

  // Maximum number of optimization snapshots kept in a buffer for saving.
  static constexpr int64 kMaxNumBufferedOptimizeArgs = 100;

  // Collects tunable parameters in the tree rooted in the given node, returning
  // a mapping from a (unique) node name to a tunable parameter.
  absl::flat_hash_map<string, std::shared_ptr<Parameter>>
  CollectTunableParameters(std::shared_ptr<Node> node);

  // Flushes metrics recorded by the model.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);

  // This optimization algorithm starts by setting all tunable parallelism
  // parameters to the minimum value. It then repeatedly identifies the
  // parameter whose increase in parallelism decreases the output time the most.
  // This process is repeated until all parameters reach their maximum values or
  // the projected output time is less than or equal to the processing time
  // needed to produce an element divided by the CPU budget.
  void OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                         const OptimizationParams& optimization_params);

  // This optimization algorithm starts by setting all tunable parallelism
  // parameters to the minimum value. It then improves the current parameters by
  // taking a step in the direction opposite to the gradient of `OutputTime` and
  // projecting the resulting values onto the feasible intervals. The
  // improvement step is repeated until either the output time improvement is
  // smaller than a threshold value or the output time is less than the
  // processing time needed to produce an element divided by the CPU budget.
  void OptimizeGradientDescent(std::shared_ptr<Node> snapshot,
                               const OptimizationParams& optimization_params);

  // Collects the output time and, if `gradients` is not `nullptr`, the output
  // time gradient w.r.t. tunable parameters of the subtree rooted in the given
  // node.
  double OutputTime(std::shared_ptr<Node> node, double model_input_time,
                    absl::flat_hash_map<string, double>* gradients);

  // Determines if we should stop the gradient descent optimization iterations
  // based on the number of increasable parameters, the CPU budget, the RAM
  // budget, and the current resource usage.
  bool ShouldStop(
      int64 cpu_budget, int64 ram_budget,
      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
          parallelism_parameters,
      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
          buffer_size_parameters,
      std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);

  // Collects the processing time for the given node.
  double TotalProcessingTime(std::shared_ptr<Node> node);

  // Collects the total number of bytes buffered in all nodes in the subtree
  // rooted in the given node for which autotuning is enabled.
  double TotalBufferedBytes(std::shared_ptr<Node> node);

  // Collects the total buffer limit of all nodes in the subtree rooted in the
  // given node for which autotuning is enabled. This number represents the
  // amount of memory that would be used by the subtree nodes if all of their
  // buffers were full.
  double TotalMaximumBufferedBytes(std::shared_ptr<Node> node);

  // Starts a model saving thread if it hasn't started yet.
  Status EnsureSaveLoopThreadStarted();

  // Periodically saves the state of optimization that is kept in
  // `save_buffer_`.
  //
  // The saving loop is terminated when the model is destroyed.
  Status SaveLoop();

  // Used for coordination between different input pipeline threads. Exclusive
  // access is required only when adding or removing nodes. Concurrent access to
  // existing nodes is protected by a node mutex.
  mutex mu_;
  // Used for coordinating the optimization loop and model modifications.
  condition_variable optimize_cond_var_;
  int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
  std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);

  // Indicates whether the modeling framework should collect resource usage
  // (e.g. CPU, memory). The logic for collecting this information assumes that
  // the collection is not repeatedly disabled and enabled. As a consequence,
  // the implementation starts collecting resource usage when it encounters a
  // tunable parameter (because the information is used for tuning the value of
  // the parameter) and never stops.
  std::atomic<bool> collect_resource_usage_;

  // Determines the time the optimization loop should wait between
  // running optimizations.
  int64 optimization_period_ms_ TF_GUARDED_BY(mu_);

  // Thread that runs the model saving loop.
  std::unique_ptr<Thread> save_thread_ TF_GUARDED_BY(mu_);

  // Used for coordinating the saving loop and model optimization.
  condition_variable save_cond_var_;

  // Indicates whether the save thread is cancelled.
  bool save_thread_cancelled_ = false;

  // Contains the path to the model saving directory if saving is enabled, and
  // is empty otherwise.
  string save_dir_;

  // Contains pairs of model snapshots and optimization parameters to be saved
  // if model saving is enabled, empty otherwise. Buffer elements are pushed by
  // `OptimizeLoop` and popped by `SaveLoop`.
  std::deque<std::pair<std::shared_ptr<Node>, OptimizationParams>> save_buffer_
      TF_GUARDED_BY(mu_);
};

}  // namespace model
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_