• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/model.h"
17 
18 #include <memory>
19 
20 #include "absl/time/clock.h"
21 #include "tensorflow/core/framework/cancellation.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 
25 namespace tensorflow {
26 namespace data {
27 namespace model {
28 
29 constexpr int64 Model::kOptimizationPeriodMinMs;
30 constexpr int64 Model::kOptimizationPeriodMaxMs;
31 
32 namespace {
33 
34 // Helper function for node traversal that doesn't skip any nodes.
IsAnyNode(const std::shared_ptr<Node> node)35 inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
36 
37 // Helper function for node traversal that filters out nodes for which
38 // autotuning is disabled.
IsAutotuneNode(const std::shared_ptr<Node> node)39 inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
40   return node->autotune();
41 }
42 
// Returns `x * x`. Small helper that keeps the gradient formulas below terse.
inline double Square(double x) {
  const double squared = x * x;
  return squared;
}
45 
46 // Collects "essential" parallelism parameters and buffer size parameters in the
47 // tree rooted in the given node. Which parallelism parameters are essential is
48 // determined by the relative processing time spent in the corresponding
49 // transformation. The collected parameters are returned via maps that map node
50 // names to their respective parameters.
CollectParameters(std::shared_ptr<Node> node,const absl::flat_hash_map<string,std::shared_ptr<Parameter>> & parameters,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parallelism_parameters,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * buffer_size_parameters)51 inline void CollectParameters(
52     std::shared_ptr<Node> node,
53     const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
54     absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
55         parallelism_parameters,
56     absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
57         buffer_size_parameters) {
58   // Parallelism parameter is considered to be essential if the corresponding
59   // transformations's processing time is greater than essential rate times the
60   // average transformation self processing time.
61   constexpr double kEssentialRate = 0.3L;
62 
63   absl::flat_hash_map<string, double> processing_times;
64   double processing_time = node->TotalProcessingTime(&processing_times);
65   double uniform_share =
66       processing_time / static_cast<double>(processing_times.size());
67   for (auto& pair : parameters) {
68     if (pair.second->name == kParallelism &&
69         processing_times[pair.first] > kEssentialRate * uniform_share) {
70       parallelism_parameters->insert(pair);
71     } else if (pair.second->name == kBufferSize) {
72       buffer_size_parameters->insert(pair);
73     }
74   }
75 }
76 
77 // Applies the gradient descent method once and updates the parameter values. If
78 // the new value is out of the range, bound it within the range between the
79 // minimal and maximum values.
UpdateParameterValues(const absl::flat_hash_map<string,double> & gradients,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters)80 inline void UpdateParameterValues(
81     const absl::flat_hash_map<string, double>& gradients,
82     absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
83   // Gradient descent step size.
84   constexpr double kDescentStep = 0.1L;
85   double new_value;
86 
87   double max_abs_derivative = 1.0;
88   for (auto& pair : *parameters) {
89     if (std::round(pair.second->value) != pair.second->max) {
90       auto* gradient = gtl::FindOrNull(gradients, pair.first);
91       if (gradient) {
92         max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
93       }
94     }
95   }
96   for (auto& pair : *parameters) {
97     auto* gradient = gtl::FindOrNull(gradients, pair.first);
98     if (gradient) {
99       new_value =
100           pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
101       // Projection on a feasible interval.
102       if (new_value > pair.second->max) {
103         pair.second->value = pair.second->max;
104       } else if (new_value < pair.second->min) {
105         pair.second->value = pair.second->min;
106       } else {
107         pair.second->value = new_value;
108       }
109     }
110   }
111 }
112 
113 // Copies the parameter values (which are for optimization tuning) and updates
114 // the state values (which are for the input pipeline to follow).
UpdateStateValues(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters)115 inline void UpdateStateValues(
116     absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
117   VLOG(2) << "Number of tunable parameters: " << parameters->size();
118   for (auto& pair : *parameters) {
119     auto& parameter = pair.second;
120     VLOG(2) << "Setting tunable parameter " << pair.first << " to "
121             << parameter->value;
122     mutex_lock l(*parameter->state->mu);
123     parameter->state->value = parameter->value;
124     parameter->state->cond_var->notify_all();
125   }
126 }
127 
128 // The first input of InterleaveMany corresponds to the input dataset whose
129 // elements are used to create the (derived) input datasets whose elements are
130 // interleaved as output.
131 //
132 // TODO(jsimsa): model the first input
133 class InterleaveMany : public Node {
134  public:
135   using Node::Node;
136 
~InterleaveMany()137   virtual ~InterleaveMany() {}
138 
139  protected:
Clone(std::shared_ptr<Node> output) const140   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
141       TF_SHARED_LOCKS_REQUIRED(mu_) {
142     return std::make_shared<InterleaveMany>(
143         Args{id_, name_, std::move(output)});
144   }
145 
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const146   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
147       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
148     double inherited_input_time;
149     if (output_) {
150       inherited_input_time = (*input_times)[output_->long_name()];
151     } else {
152       inherited_input_time = (*input_times)[kModelInputTimeKey];
153     }
154 
155     if (num_inputs() <= 1) {
156       (*input_times)[long_name()] = inherited_input_time;
157       return;
158     }
159     // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
160     // input time for InterleaveMany node to call one of the
161     // `(num_inputs() - 1)` input nodes (except first input) to return an
162     // element. Regardless of the `block_length` parameter of InterleaveMany
163     // node, the average input time for any of the `(num_inputs() - 1)` input
164     // nodes to be called is computed as:
165     double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
166                         static_cast<double>(num_inputs() - 1);
167     (*input_times)[long_name()] = input_time;
168   }
169 
170   // The output time is the sum of the self processing time and the average
171   // output time of inputs comprising the interleave "cycle".
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const172   void OutputTimeLocked(
173       const absl::flat_hash_map<string, double>& input_times,
174       absl::flat_hash_map<string, double>* gradients,
175       absl::flat_hash_map<string, double>* output_times,
176       absl::flat_hash_map<string, double>* output_time_gradients) const override
177       TF_SHARED_LOCKS_REQUIRED(mu_) {
178     double self_processing_time = SelfProcessingTimeLocked();
179     if (num_inputs() <= 1) {
180       (*output_times)[long_name()] = self_processing_time;
181       if (gradients) {
182         for (const auto& node :
183              CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
184           gradients->erase(node->long_name());
185         }
186       }
187       return;
188     }
189 
190     double inputs_output_time =
191         (OutputTimeForInputs(*output_times) -
192          (*output_times)[inputs_.front()->long_name()]) /
193         static_cast<double>(num_inputs() - 1);
194     if (gradients) {
195       for (const auto& node :
196            CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
197         auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
198         if (gradient) {
199           *gradient /= static_cast<double>(num_inputs() - 1);
200         }
201       }
202 
203       (*output_time_gradients)[long_name()] =
204           OutputTimeGradientsForInputs(*output_time_gradients) -
205           (*output_time_gradients)[inputs_.front()->long_name()];
206 
207       // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
208       // first input equal to 0 since its output time is excluded from
209       // computations.
210       absl::flat_hash_map<string, std::shared_ptr<Parameter>>
211           first_input_parameters;
212       inputs_.front()->CollectTunableParameters(&first_input_parameters);
213       for (auto& pair : first_input_parameters) {
214         (*gradients)[pair.first] = 0.0L;
215       }
216     }
217     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
218   }
219 
220   // The processing time is the sum of the self processing time and the average
221   // processing time of inputs comprising the interleave "cycle".
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)222   void TotalProcessingTimeLocked(
223       absl::flat_hash_map<string, double>* processing_times,
224       absl::flat_hash_map<string, double>* total_processing_times) override
225       TF_SHARED_LOCKS_REQUIRED(mu_) {
226     double self_processing_time = SelfProcessingTimeLocked();
227     if (processing_times) {
228       (*processing_times)[long_name()] = self_processing_time;
229     }
230     if (num_inputs() <= 1) {
231       (*total_processing_times)[long_name()] = self_processing_time;
232       return;
233     }
234     double inputs_processing_time =
235         (TotalProcessingTimeForInputs(*total_processing_times) -
236          (*total_processing_times)[inputs_.front()->long_name()]) /
237         static_cast<double>(num_inputs() - 1);
238     (*total_processing_times)[long_name()] =
239         self_processing_time + inputs_processing_time;
240   }
241 
ToProto(ModelProto::Node * node_proto) const242   Status ToProto(ModelProto::Node* node_proto) const {
243     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
244     node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
245     return Status::OK();
246   }
247 };
248 
249 // The first input of AsyncInterleaveMany corresponds to the input dataset whose
250 // elements are used to create the (derived) input datasets whose elements are
251 // interleaved as output.
252 //
253 // TODO(jsimsa): model the first input
254 class AsyncInterleaveMany : public Node {
255  public:
AsyncInterleaveMany(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)256   AsyncInterleaveMany(Node::Args args,
257                       std::vector<std::shared_ptr<Parameter>> parameters)
258       : Node(args) {
259     for (auto& parameter : parameters) {
260       parameters_[parameter->name] = std::move(parameter);
261     }
262   }
263 
~AsyncInterleaveMany()264   virtual ~AsyncInterleaveMany() {}
265 
266  protected:
Clone(std::shared_ptr<Node> output) const267   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
268       TF_SHARED_LOCKS_REQUIRED(mu_) {
269     std::vector<std::shared_ptr<Parameter>> parameters;
270     for (auto& pair : parameters_) {
271       parameters.push_back(pair.second);
272     }
273     return std::make_shared<AsyncInterleaveMany>(
274         Args{id_, name_, std::move(output)}, parameters);
275   }
276 
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const277   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
278       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
279     double inherited_input_time;
280     if (output_) {
281       inherited_input_time = (*input_times)[output_->long_name()];
282     } else {
283       inherited_input_time = (*input_times)[kModelInputTimeKey];
284     }
285 
286     if (num_inputs() <= 1) {
287       (*input_times)[long_name()] = inherited_input_time;
288       return;
289     }
290     // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
291     // input time for AsyncInterleaveMany node to call one of the
292     // `(num_inputs() - 1)` input nodes (except first input) to return an
293     // element. Regardless of the `block_length` parameter of
294     // AsyncInterleaveMany node, the average input time for any of the
295     // `(num_inputs() - 1)` input nodes to be called is computed as:
296     double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
297                         static_cast<double>(num_inputs() - 1);
298     (*input_times)[long_name()] = input_time;
299   }
300 
301   // The output time is the sum of self processing time and expected wait time
302   // from the buffer model estimated using
303   // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
304   // `producer_time` is the average output time of inputs comprising the
305   // interleave "cycle" divided by `parallelism`, `consumer_time` is the
306   // `input_time` specified through `input_times` divided by `num_inputs() - 1`,
307   // and if the node has parallelism parameter, then `buffer_size` is derived
308   // from `parallelism`.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const309   void OutputTimeLocked(
310       const absl::flat_hash_map<string, double>& input_times,
311       absl::flat_hash_map<string, double>* gradients,
312       absl::flat_hash_map<string, double>* output_times,
313       absl::flat_hash_map<string, double>* output_time_gradients) const override
314       TF_SHARED_LOCKS_REQUIRED(mu_) {
315     double self_processing_time = SelfProcessingTimeLocked();
316     if (num_inputs() <= 1) {
317       (*output_times)[long_name()] = self_processing_time;
318       if (gradients) {
319         for (const auto& node :
320              CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
321           gradients->erase(node->long_name());
322         }
323       }
324       return;
325     }
326 
327     double output_time, wait_time, consumer_time, producer_time;
328     double input_time = input_times.at(long_name());
329     consumer_time = input_time / static_cast<double>(num_inputs() - 1);
330     double parallelism = num_inputs() - 1;  // default to cycle length
331     auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
332     if (parameter) {
333       parallelism = std::min(parallelism, (*parameter)->value);
334     }
335     double output_time_for_inputs =
336         OutputTimeForInputs(*output_times) -
337         (*output_times)[inputs_.front()->long_name()];
338     producer_time = output_time_for_inputs /
339                     static_cast<double>(num_inputs() - 1) / parallelism;
340 
341     if (gradients) {
342       double producer_time_der = 0.0L;
343       double consumer_time_der = 0.0L;
344       double buffer_size_der = 0.0L;
345       wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
346                                   &producer_time_der, &consumer_time_der,
347                                   &buffer_size_der);
348       double inputs_time_der_sum =
349           OutputTimeGradientsForInputs(*output_time_gradients);
350       (*output_time_gradients)[long_name()] =
351           consumer_time_der +
352           producer_time_der * inputs_time_der_sum / parallelism;
353 
354       for (const auto& node :
355            CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
356         auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
357         if (gradient) {
358           *gradient *= (producer_time_der /
359                         static_cast<double>(num_inputs() - 1) / parallelism);
360         }
361       }
362 
363       // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
364       // first input equal to 0 since its output time is excluded from
365       // computations.
366       absl::flat_hash_map<string, std::shared_ptr<Parameter>>
367           first_input_parameters;
368       inputs_.front()->CollectTunableParameters(&first_input_parameters);
369       for (auto& pair : first_input_parameters) {
370         (*gradients)[pair.first] = 0.0L;
371       }
372       // Add derivative w.r.t. own parallelism parameter.
373       if (parameter && (*parameter)->state->tunable) {
374         (*gradients)[long_name()] =
375             buffer_size_der - producer_time_der * producer_time / parallelism;
376       }
377     } else {
378       wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
379                                   /*producer_time_derivative=*/nullptr,
380                                   /*consumer_time_derivative=*/nullptr,
381                                   /*buffer_size_derivative=*/nullptr);
382     }
383     output_time = self_processing_time + wait_time;
384     (*output_times)[long_name()] = output_time;
385   }
386 
387   // The processing time is the sum of the self processing time and the average
388   // processing time of inputs comprising the interleave "cycle".
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)389   void TotalProcessingTimeLocked(
390       absl::flat_hash_map<string, double>* processing_times,
391       absl::flat_hash_map<string, double>* total_processing_times) override
392       TF_SHARED_LOCKS_REQUIRED(mu_) {
393     double self_processing_time = SelfProcessingTimeLocked();
394     if (processing_times) {
395       (*processing_times)[long_name()] = self_processing_time;
396     }
397     if (num_inputs() <= 1) {
398       (*total_processing_times)[long_name()] = self_processing_time;
399       return;
400     }
401     double inputs_processing_time =
402         (TotalProcessingTimeForInputs(*total_processing_times) -
403          (*total_processing_times)[inputs_.front()->long_name()]) /
404         static_cast<double>(num_inputs() - 1);
405     (*total_processing_times)[long_name()] =
406         self_processing_time + inputs_processing_time;
407   }
408 
MaximumBufferedBytes() const409   double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
410     double result = 0;
411     auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
412     if (parameter) {
413       result += (*parameter)->value * AverageBufferedElementSize();
414     }
415     return result;
416   }
417 
ToProto(ModelProto::Node * node_proto) const418   Status ToProto(ModelProto::Node* node_proto) const {
419     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
420     node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
421     return Status::OK();
422   }
423 };
424 
425 class KnownRatio : public Node {
426  public:
KnownRatio(Node::Args args,double ratio)427   KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}
428 
~KnownRatio()429   virtual ~KnownRatio() {}
430 
431  protected:
Clone(std::shared_ptr<Node> output) const432   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
433       TF_SHARED_LOCKS_REQUIRED(mu_) {
434     return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
435                                         ratio_);
436   }
437 
438   // The input time is the sum of inherited input time and self processing time,
439   // divided by `ratio_`.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const440   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
441       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
442     double inherited_input_time;
443     if (output_) {
444       inherited_input_time = (*input_times)[output_->long_name()];
445     } else {
446       inherited_input_time = (*input_times)[kModelInputTimeKey];
447     }
448 
449     if (ratio_ == 0) {
450       (*input_times)[long_name()] = inherited_input_time;
451       return;
452     }
453     double input_time =
454         (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
455     (*input_times)[long_name()] = input_time;
456   }
457 
458   // The output time is the sum of the self processing time and the product of
459   // `ratio_` and the sum of output times of inputs.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const460   void OutputTimeLocked(
461       const absl::flat_hash_map<string, double>& input_times,
462       absl::flat_hash_map<string, double>* gradients,
463       absl::flat_hash_map<string, double>* output_times,
464       absl::flat_hash_map<string, double>* output_time_gradients) const override
465       TF_SHARED_LOCKS_REQUIRED(mu_) {
466     double self_processing_time = SelfProcessingTimeLocked();
467     if (ratio_ == 0) {
468       (*output_times)[long_name()] = self_processing_time;
469       if (gradients) {
470         for (const auto& node :
471              CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
472           gradients->erase(node->long_name());
473         }
474       }
475       return;
476     }
477     if (gradients) {
478       for (const auto& node :
479            CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
480         auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
481         if (gradient) {
482           *gradient *= ratio_;
483         }
484       }
485       (*output_time_gradients)[long_name()] =
486           OutputTimeGradientsForInputs(*output_time_gradients);
487     }
488     double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
489     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
490   }
491 
492   // The processing time is the sum of the self processing time and the product
493   // of `ratio_` and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)494   void TotalProcessingTimeLocked(
495       absl::flat_hash_map<string, double>* processing_times,
496       absl::flat_hash_map<string, double>* total_processing_times) override
497       TF_SHARED_LOCKS_REQUIRED(mu_) {
498     double self_processing_time = SelfProcessingTimeLocked();
499     if (processing_times) {
500       (*processing_times)[long_name()] = self_processing_time;
501     }
502     if (ratio_ == 0) {
503       (*total_processing_times)[long_name()] = self_processing_time;
504       return;
505     }
506     double inputs_processing_time =
507         ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
508     (*total_processing_times)[long_name()] =
509         self_processing_time + inputs_processing_time;
510   }
511 
ToProto(ModelProto::Node * node_proto) const512   Status ToProto(ModelProto::Node* node_proto) const {
513     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
514     node_proto->set_node_class(NodeClass::KNOWN_RATIO);
515     node_proto->set_ratio(ratio_);
516     return Status::OK();
517   }
518 
519  private:
520   const double ratio_;
521 };
522 
523 class AsyncKnownRatio : public Node {
524  public:
AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                std::vector<std::shared_ptr<Parameter>> parameters)
    : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
  // Index the given parameters by their name so they can later be looked up
  // by `kParallelism` / `kBufferSize`.
  for (std::shared_ptr<Parameter>& p : parameters) {
    parameters_[p->name] = std::move(p);
  }
}
532 
// Nothing to release beyond what the members and `Node` already clean up.
virtual ~AsyncKnownRatio() = default;
534 
535  protected:
Clone(std::shared_ptr<Node> output) const536   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
537       TF_SHARED_LOCKS_REQUIRED(mu_) {
538     std::vector<std::shared_ptr<Parameter>> parameters;
539     for (auto& pair : parameters_) {
540       parameters.push_back(pair.second);
541     }
542     return std::make_shared<AsyncKnownRatio>(
543         Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
544   }
545 
546   // The input time is the sum of inherited input time and parallelism adjusted
547   // self processing time, divided by `ratio_`.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const548   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
549       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
550     double inherited_input_time;
551     if (output_) {
552       inherited_input_time = (*input_times)[output_->long_name()];
553     } else {
554       inherited_input_time = (*input_times)[kModelInputTimeKey];
555     }
556     double parallelism = 1.0;
557     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
558     if (parallelism_parameter) {
559       parallelism = (*parallelism_parameter)->value;
560     }
561 
562     if (ratio_ == 0.0) {
563       (*input_times)[long_name()] =
564           inherited_input_time + SelfProcessingTimeLocked() / parallelism;
565       return;
566     }
567     double input_time =
568         (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
569         ratio_;
570     (*input_times)[long_name()] = input_time;
571   }
572 
573   // The output time is the sum of parallelism adjusted self processing time and
574   // expected wait time from the buffer model estimated using
575   // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
576   // `producer_time` is the product of `ratio_` and the sum of output times of
577   // inputs, `consumer_time` is the product of `ratio_` and the `input_time`
578   // specified through `input_times` (since for each element stored in the
579   // buffer, the inputs need to be called `ratio_` times), and if the node has
580   // parallelism parameter, then `buffer_size` is derived from `parallelism`.
581   //
582   // Current implementation assumes that there is at most 1 parameter per node.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const583   void OutputTimeLocked(
584       const absl::flat_hash_map<string, double>& input_times,
585       absl::flat_hash_map<string, double>* gradients,
586       absl::flat_hash_map<string, double>* output_times,
587       absl::flat_hash_map<string, double>* output_time_gradients) const override
588       TF_SHARED_LOCKS_REQUIRED(mu_) {
589     double parallelism = 1.0;
590     double buffer_size = 0.0;
591     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
592     auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
593     if (parallelism_parameter) {
594       parallelism = (*parallelism_parameter)->value;
595       if (ratio_ == 0) {
596         buffer_size = parallelism;
597       } else {
598         // Currently, MapAndBatch is the only transformation creates
599         // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
600         // `parallelism` threads to apply the function on elements from input
601         // dataset, while one element in the buffer actually corresponds to
602         // `ratio_` elements from input dataset. So we adjust the `buffer_size`
603         // by dividing `ratio_`.
604         buffer_size = parallelism / ratio_;
605       }
606     } else if (buffer_size_parameter) {
607       buffer_size = (*buffer_size_parameter)->value;
608     }
609     double self_processing_time = SelfProcessingTimeLocked();
610     double output_time, wait_time, consumer_time, producer_time;
611     double input_time = input_times.at(long_name());
612 
613     if (ratio_ == 0) {
614       consumer_time = input_time;
615       producer_time = 0.0L;
616       if (gradients) {
617         for (const auto& node :
618              CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
619           gradients->erase(node->long_name());
620         }
621 
622         double producer_time_der = 0.0L;
623         double consumer_time_der = 0.0L;
624         double buffer_size_der = 0.0L;
625         wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
626                                     &producer_time_der, &consumer_time_der,
627                                     &buffer_size_der);
628         (*output_time_gradients)[long_name()] = consumer_time_der;
629         if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
630           (*gradients)[long_name()] = -(1.0L + consumer_time_der) *
631                                           self_processing_time /
632                                           Square(parallelism) +
633                                       buffer_size_der;
634         } else if (buffer_size_parameter &&
635                    (*buffer_size_parameter)->state->tunable) {
636           (*gradients)[long_name()] = buffer_size_der;
637         }
638       } else {
639         wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
640                                     /*producer_time_derivative=*/nullptr,
641                                     /*consumer_time_derivative=*/nullptr,
642                                     /*buffer_size_derivative=*/nullptr);
643       }
644       output_time = self_processing_time / parallelism + wait_time;
645       (*output_times)[long_name()] = output_time;
646       return;
647     }
648 
649     consumer_time = input_time * ratio_;
650     producer_time = ratio_ * OutputTimeForInputs(*output_times);
651     if (gradients) {
652       double producer_time_der = 0.0L;
653       double consumer_time_der = 0.0L;
654       double buffer_size_der = 0.0L;
655       wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
656                                   &producer_time_der, &consumer_time_der,
657                                   &buffer_size_der);
658       double inputs_time_der_sum =
659           OutputTimeGradientsForInputs(*output_time_gradients);
660       (*output_time_gradients)[long_name()] =
661           consumer_time_der + producer_time_der * inputs_time_der_sum;
662 
663       for (const auto& node :
664            CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
665         auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
666         if (gradient) {
667           *gradient *= (ratio_ * producer_time_der);
668         }
669       }
670 
671       // Add derivative w.r.t. own parameter if it's tunable.
672       if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
673         (*gradients)[long_name()] = buffer_size_der / ratio_ -
674                                     (1.0L + consumer_time_der +
675                                      producer_time_der * inputs_time_der_sum) *
676                                         self_processing_time /
677                                         Square(parallelism);
678       } else if (buffer_size_parameter &&
679                  (*buffer_size_parameter)->state->tunable) {
680         (*gradients)[long_name()] = buffer_size_der;
681       }
682     } else {
683       wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
684                                   /*producer_time_derivative=*/nullptr,
685                                   /*consumer_time_derivative=*/nullptr,
686                                   /*buffer_size_derivative=*/nullptr);
687     }
688     output_time = self_processing_time / parallelism + wait_time;
689     (*output_times)[long_name()] = output_time;
690   }
691 
692   // The processing time is the sum of the self processing time and the product
693   // of `ratio_` and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)694   void TotalProcessingTimeLocked(
695       absl::flat_hash_map<string, double>* processing_times,
696       absl::flat_hash_map<string, double>* total_processing_times) override
697       TF_SHARED_LOCKS_REQUIRED(mu_) {
698     double self_processing_time = SelfProcessingTimeLocked();
699     if (processing_times) {
700       (*processing_times)[long_name()] = self_processing_time;
701     }
702     if (ratio_ == 0) {
703       (*total_processing_times)[long_name()] = self_processing_time;
704       return;
705     }
706     double inputs_processing_time =
707         ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
708     (*total_processing_times)[long_name()] =
709         self_processing_time + inputs_processing_time;
710   }
711 
MaximumBufferedBytes() const712   double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
713     double result = 0;
714     auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
715     if (!parameter) {
716       parameter = gtl::FindOrNull(parameters_, kParallelism);
717     }
718 
719     if (parameter) {
720       if (memory_ratio_ == 0) {
721         result += (*parameter)->value * AverageBufferedElementSize();
722       } else {
723         // The estimation is currently not accurate for MapAndBatchDataset for
724         // the maximum buffer size does not match `num_parallel_calls`
725         // parameter.
726         result +=
727             (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
728       }
729     }
730     return result;
731   }
732 
ToProto(ModelProto::Node * node_proto) const733   Status ToProto(ModelProto::Node* node_proto) const {
734     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
735     node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
736     node_proto->set_ratio(ratio_);
737     node_proto->set_memory_ratio(memory_ratio_);
738     return Status::OK();
739   }
740 
private:
  // Identifies how many input elements need to be created to construct an
  // element for the dataset.
  //
  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset,
  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
  //
  // A value of 0 makes TotalProcessingTimeLocked() report only this node's
  // self processing time, without consulting its inputs.
  const double ratio_;
  // For parallelism nodes, identifies how many parallelism calls are introduced
  // by one buffered element. The value is defined to correctly estimate RAM
  // budget bound with given num_parallel_calls (or buffer_size) combined with
  // the estimated average size of buffered elements.
  //
  // MaximumBufferedBytes() divides its estimate by this value; a value of 0
  // skips that division.
  const double memory_ratio_;
753 };
754 
755 class UnknownRatio : public Node {
756  public:
757   using Node::Node;
758 
~UnknownRatio()759   virtual ~UnknownRatio() {}
760 
761  protected:
Clone(std::shared_ptr<Node> output) const762   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
763       TF_SHARED_LOCKS_REQUIRED(mu_) {
764     return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
765   }
766 
767   // The input time is the sum of inherited input time and self processing time,
768   // divided by the ratio estimate.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const769   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
770       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
771     double inherited_input_time;
772     if (output_) {
773       inherited_input_time = (*input_times)[output_->long_name()];
774     } else {
775       inherited_input_time = (*input_times)[kModelInputTimeKey];
776     }
777 
778     if (num_elements_ == 0 || inputs_.empty() ||
779         inputs_.front()->num_elements() == 0) {
780       (*input_times)[long_name()] = inherited_input_time;
781       return;
782     }
783     std::shared_ptr<Node> input = inputs_.front();
784     double ratio = static_cast<double>(input->num_elements()) /
785                    static_cast<double>(num_elements_);
786     double input_time =
787         (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
788     (*input_times)[long_name()] = input_time;
789   }
790 
791   // The output time is the sum of the self processing time and the product of
792   // the ratio estimate and the sum of output times of inputs.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const793   void OutputTimeLocked(
794       const absl::flat_hash_map<string, double>& input_times,
795       absl::flat_hash_map<string, double>* gradients,
796       absl::flat_hash_map<string, double>* output_times,
797       absl::flat_hash_map<string, double>* output_time_gradients) const override
798       TF_SHARED_LOCKS_REQUIRED(mu_) {
799     double self_processing_time = SelfProcessingTimeLocked();
800     if (num_elements_ == 0 || inputs_.empty() ||
801         inputs_.front()->num_elements() == 0) {
802       (*output_times)[long_name()] = self_processing_time;
803       if (gradients) {
804         for (const auto& node :
805              CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
806           gradients->erase(node->long_name());
807         }
808       }
809       return;
810     }
811     // TODO(jsimsa): The current implementation assumes that the number of input
812     // elements consumed per output is the same across all inputs.
813     double ratio = static_cast<double>(inputs_.front()->num_elements()) /
814                    static_cast<double>(num_elements_);
815     if (gradients) {
816       for (const auto& node :
817            CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
818         auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
819         if (gradient) {
820           *gradient *= ratio;
821         }
822       }
823       (*output_time_gradients)[long_name()] =
824           OutputTimeGradientsForInputs(*output_time_gradients);
825     }
826     double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
827     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
828   }
829 
830   // The processing time is the sum of the self processing time and the product
831   // of the ratio estimate and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)832   void TotalProcessingTimeLocked(
833       absl::flat_hash_map<string, double>* processing_times,
834       absl::flat_hash_map<string, double>* total_processing_times) override
835       TF_SHARED_LOCKS_REQUIRED(mu_) {
836     double self_processing_time = SelfProcessingTimeLocked();
837     if (processing_times) {
838       (*processing_times)[long_name()] = self_processing_time;
839     }
840     if (inputs_.empty() || num_elements_ == 0) {
841       (*total_processing_times)[long_name()] = self_processing_time;
842       return;
843     }
844     // TODO(jsimsa): The current implementation assumes that the number of input
845     // elements consumed per output is the same across all inputs.
846     std::shared_ptr<Node> input = inputs_.front();
847     double ratio = static_cast<double>(input->num_elements()) /
848                    static_cast<double>(num_elements_);
849     double inputs_processing_time =
850         ratio * TotalProcessingTimeForInputs(*total_processing_times);
851     (*total_processing_times)[long_name()] =
852         self_processing_time + inputs_processing_time;
853   }
854 
ToProto(ModelProto::Node * node_proto) const855   Status ToProto(ModelProto::Node* node_proto) const {
856     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
857     node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
858     return Status::OK();
859   }
860 };
861 
862 class Unknown : public Node {
863  public:
864   using Node::Node;
865 
~Unknown()866   virtual ~Unknown() {}
867 
868  protected:
Clone(std::shared_ptr<Node> output) const869   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
870       TF_SHARED_LOCKS_REQUIRED(mu_) {
871     return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
872   }
873 
874   // The input time is the inherited input time.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const875   void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
876       const override TF_SHARED_LOCKS_REQUIRED(mu_) {
877     double inherited_input_time;
878     if (output_) {
879       inherited_input_time = (*input_times)[output_->long_name()];
880     } else {
881       inherited_input_time = (*input_times)[kModelInputTimeKey];
882     }
883     (*input_times)[long_name()] = inherited_input_time;
884   }
885 
886   // The output time is the sum of output times of inputs.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const887   void OutputTimeLocked(
888       const absl::flat_hash_map<string, double>& input_times,
889       absl::flat_hash_map<string, double>* gradients,
890       absl::flat_hash_map<string, double>* output_times,
891       absl::flat_hash_map<string, double>* output_time_gradients) const override
892       TF_SHARED_LOCKS_REQUIRED(mu_) {
893     (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
894     if (gradients) {
895       (*output_time_gradients)[long_name()] =
896           OutputTimeGradientsForInputs(*output_time_gradients);
897     }
898   }
899 
900   // The processing time is the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)901   void TotalProcessingTimeLocked(
902       absl::flat_hash_map<string, double>* processing_times,
903       absl::flat_hash_map<string, double>* total_processing_times) override
904       TF_SHARED_LOCKS_REQUIRED(mu_) {
905     if (processing_times) {
906       (*processing_times)[long_name()] = SelfProcessingTimeLocked();
907     }
908     (*total_processing_times)[long_name()] =
909         TotalProcessingTimeForInputs(*total_processing_times);
910   }
911 
ToProto(ModelProto::Node * node_proto) const912   Status ToProto(ModelProto::Node* node_proto) const {
913     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
914     node_proto->set_node_class(NodeClass::UNKNOWN);
915     return Status::OK();
916   }
917 };
918 
919 }  // namespace
920 
// Out-of-line definition of the static data member `Node::work_start_`.
// Being `thread_local`, each thread has its own independent copy.
thread_local int64 Node::work_start_;
922 
MakeParameter(const string & name,std::shared_ptr<SharedState> state,double min,double max)923 std::shared_ptr<Parameter> MakeParameter(const string& name,
924                                          std::shared_ptr<SharedState> state,
925                                          double min, double max) {
926   return std::make_shared<Parameter>(name, state, min, max);
927 }
928 
MakeInterleaveManyNode(Node::Args args)929 std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args) {
930   return std::make_shared<InterleaveMany>(std::move(args));
931 }
932 
MakeAsyncInterleaveManyNode(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)933 std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
934     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
935   return std::make_shared<AsyncInterleaveMany>(std::move(args),
936                                                std::move(parameters));
937 }
938 
MakeKnownRatioNode(Node::Args args,double ratio)939 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
940   return std::make_shared<KnownRatio>(std::move(args), ratio);
941 }
942 
MakeAsyncKnownRatioNode(Node::Args args,double ratio,double memory_ratio,std::vector<std::shared_ptr<Parameter>> parameters)943 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
944     Node::Args args, double ratio, double memory_ratio,
945     std::vector<std::shared_ptr<Parameter>> parameters) {
946   return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
947                                            std::move(parameters));
948 }
949 
MakeAsyncKnownRatioNode(Node::Args args,double ratio,std::vector<std::shared_ptr<Parameter>> parameters)950 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
951     Node::Args args, double ratio,
952     std::vector<std::shared_ptr<Parameter>> parameters) {
953   return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
954                                  /*memory_ratio=*/ratio, std::move(parameters));
955 }
956 
MakeSourceNode(Node::Args args)957 std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
958   return MakeKnownRatioNode(std::move(args), 0);
959 }
960 
MakeUnknownRatioNode(Node::Args args)961 std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
962   return std::make_shared<UnknownRatio>(std::move(args));
963 }
964 
MakeUnknownNode(Node::Args args)965 std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
966   return std::make_shared<Unknown>(std::move(args));
967 }
968 
// Returns the expected wait time T = p * y for a consumer (time x =
// `consumer_time`) reading from a producer (time y = `producer_time`) through
// a buffer of size n = `buffer_size`, where p is the probability that the
// buffer is empty (see the formulas below).
//
// Each of `producer_time_derivative`, `consumer_time_derivative` and
// `buffer_size_derivative` that is non-null receives dT/dy, dT/dx or dT/dn
// respectively; null pointers are never written.
//
// Note: literals with an `L` suffix (e.g. `1.0L`) are `long double`, so many
// expressions below are evaluated in long double and rounded to double on
// assignment. Reordering them can change results in the last bits.
double Node::ComputeWaitTime(const double& producer_time,
                             const double& consumer_time,
                             const double& buffer_size,
                             double* producer_time_derivative,
                             double* consumer_time_derivative,
                             double* buffer_size_derivative) {
  // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
  // p=`p_buffer_empty`, T=`wait_time`, then we have:
  // if y = 0, then p = 0;
  // elif x = 0, then p = 1;
  // elif x = y, then p = 1 / (n+1);
  // else p = [1 - x/y] / [1 - power(x/y, n+1)].
  //
  // We also have T = p * y, and derivatives of T w.r.t. x, y, n are computed:
  // dT/dx = dp/dx * y,
  // dT/dy = p + dp/dy * y,
  // dT/dn = dp/dn * y.
  // Then the remaining work is to compute dp/dx, dp/dy, dp/dn by considering
  // different cases and substitute the values into above formulas.

  // Case 1: if producer is infinitely fast. The buffer will always be full.
  // Wait time will always be 0.
  if (producer_time == 0) {
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = 0` since p=0 on the
      // line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy at
      // (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy, where
      // p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
      // That limit gives dT/dy = 1 when n = 0 or x = 0 (the buffer is then
      // always empty), and 0 otherwise.
      if (buffer_size == 0 || consumer_time == 0) {
        *producer_time_derivative = 1.0L;
      } else {
        *producer_time_derivative = 0.0L;
      }
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = 0.0L;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return 0.0L;
  }

  // Case 2: if consumer is infinitely fast. Wait time is always the time to
  // produce an output.
  if (consumer_time == 0) {
    if (producer_time_derivative) {
      *producer_time_derivative = 1.0L;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since p=1 on the
      // line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx at
      // (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx, where
      // p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
      // For n = 0, p is identically 1, so the derivative is 0; for n >= 1,
      // p(dx,y) ~= 1 - dx/y, so dT/dx = (dp/dx) * y = -1.
      if (buffer_size == 0) {
        *consumer_time_derivative = 0.0L;
      } else {
        *consumer_time_derivative = -1.0L;
      }
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return producer_time;
  }

  // Case 3: the consumer and the producer are equally fast. Expected wait time
  // is producer_time / (buffer_size + 1), i.e. inversely proportional to the
  // buffer size plus one.
  if (consumer_time == producer_time) {
    const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
    // dp/dr at r = x/y = 1, where p(r) = [1 - r] / [1 - power(r, n+1)]; this
    // evaluates to -n / (2n + 2).
    const double p_buffer_empty_der =
        -buffer_size / (2.0L * buffer_size + 2.0L);
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = p_buffer_empty`
      // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there. Actually
      // to compute dp/dy at (y,y), we need to consider
      // lim_{dy->0} [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1),
      // p(y,y+dy) = [1 - y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
      *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since
      // p=1/(n+1) on the line x=y doesn't imply dp/dx = 0 there. Actually to
      // compute dp/dx at (x,x), we need to consider
      // lim_{dx->0} [p(x+dx,x)-p(x,x)] / dx, where p(x,x)=1/(n+1),
      // p(x+dx,x) = [1 - (x+dx)/x] / [1 - power((x+dx)/x, n+1)].
      *consumer_time_derivative = p_buffer_empty_der;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
    }
    return p_buffer_empty * producer_time;
  }

  // Case 4: the consumer is slower than the producer and neither is infinitely
  // fast. Case 4 and Case 5 actually follow same formula. Separate them for
  // numerical computation reasons: in each case `ratio` is chosen as the
  // smaller time over the larger one, so (for non-negative times) it is below
  // 1 and `std::pow(ratio, buffer_size)` cannot overflow.
  if (consumer_time > producer_time) {
    // Here the formula is rewritten in terms of s = y/x:
    // p = s^n * (1 - s) / (1 - s^(n+1)).
    const double ratio = producer_time / consumer_time;
    const double ratio_pow = std::pow(ratio, buffer_size);
    const double p_buffer_empty =
        ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
    // dp/ds = s^(n-1) * (n - (n+1)s + s^(n+1)) / (1 - s^(n+1))^2.
    const double p_buffer_empty_der =
        (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
        ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
    if (producer_time_derivative) {
      *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                                std::log(ratio) * producer_time;
    }
    return p_buffer_empty * producer_time;
  }

  // Case 5: the producer is slower than the consumer and neither is infinitely
  // fast. Here r = x/y < 1 and p = (1 - r) / (1 - r^(n+1)).
  const double ratio = consumer_time / producer_time;
  const double ratio_pow = std::pow(ratio, buffer_size);
  const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
  // dp/dr = ((n + 1 - n*r) * r^n - 1) / (1 - r^(n+1))^2.
  const double p_buffer_empty_der =
      ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
      Square(1.0L - ratio_pow * ratio);
  if (producer_time_derivative) {
    *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
  }
  if (consumer_time_derivative) {
    *consumer_time_derivative = p_buffer_empty_der;
  }
  if (buffer_size_derivative) {
    *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                              ratio_pow * ratio * std::log(ratio) *
                              producer_time;
  }
  return p_buffer_empty * producer_time;
}
1108 
CollectTunableParameters(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters) const1109 void Node::CollectTunableParameters(
1110     absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const {
1111   tf_shared_lock l(mu_);
1112   // Collect tunable parameters from the leaves of the nodes tree to the root.
1113   for (const auto& node :
1114        CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1115     tf_shared_lock l(node->mu_);
1116     node->CollectTunableParametersHelper(parameters);
1117   }
1118   CollectTunableParametersHelper(parameters);
1119 }
1120 
DebugString() const1121 string Node::DebugString() const {
1122   absl::flat_hash_map<string, string> debug_strings;
1123   tf_shared_lock l(mu_);
1124   // Build up the debug string from the leaves of the nodes tree to the root.
1125   for (const auto& node :
1126        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1127     tf_shared_lock l(node->mu_);
1128     node->DebugStringHelper(&debug_strings);
1129   }
1130   DebugStringHelper(&debug_strings);
1131 
1132   return debug_strings[long_name()];
1133 }
1134 
FlushMetrics()1135 void Node::FlushMetrics() {
1136   if (!record_metrics_) {
1137     return;
1138   }
1139   metrics_.record_bytes_consumed(bytes_consumed_);
1140   metrics_.record_bytes_produced(bytes_produced_);
1141   metrics_.record_num_elements(num_elements_);
1142 }
1143 
OutputTime(absl::flat_hash_map<string,double> * input_times,absl::flat_hash_map<string,double> * gradients) const1144 double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
1145                         absl::flat_hash_map<string, double>* gradients) const {
1146   // To store the output time gradient w.r.t. input time (if `gradients` is not
1147   // `nullptr`) and the output time for each node.
1148   absl::flat_hash_map<string, double> output_time_gradients, output_times;
1149   tf_shared_lock l(mu_);
1150   auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);
1151 
1152   // Computes and stores input time for each node from the root to leaves of the
1153   // nodes tree.
1154   InputTimeLocked(input_times);
1155   for (const auto& node : nodes) {
1156     tf_shared_lock l(node->mu_);
1157     node->InputTimeLocked(input_times);
1158   }
1159 
1160   std::reverse(nodes.begin(), nodes.end());
1161   // Computes and stores the output time and output time gradient w.r.t. input
1162   // time (if `gradients` is not `nullptr`) for each node from leaves of the
1163   // nodes tree to the root.
1164   for (const auto& node : nodes) {
1165     tf_shared_lock l(node->mu_);
1166     node->OutputTimeLocked(*input_times, gradients, &output_times,
1167                            &output_time_gradients);
1168   }
1169   OutputTimeLocked(*input_times, gradients, &output_times,
1170                    &output_time_gradients);
1171 
1172   return output_times[long_name()];
1173 }
1174 
Snapshot() const1175 std::shared_ptr<Node> Node::Snapshot() const {
1176   NodePairList node_pairs;
1177   auto result = SnapshotHelper(nullptr, &node_pairs);
1178 
1179   while (!node_pairs.empty()) {
1180     auto node_pair = node_pairs.front();
1181     node_pairs.pop_front();
1182     std::shared_ptr<Node> current = node_pair.first,
1183                           cloned_output = node_pair.second;
1184     cloned_output->add_input(
1185         current->SnapshotHelper(cloned_output, &node_pairs));
1186   }
1187   return result;
1188 }
1189 
SelfProcessingTime() const1190 double Node::SelfProcessingTime() const {
1191   tf_shared_lock l(mu_);
1192   return SelfProcessingTimeLocked();
1193 }
1194 
TotalBufferedBytes() const1195 double Node::TotalBufferedBytes() const {
1196   absl::flat_hash_map<string, double> total_bytes;
1197   tf_shared_lock l(mu_);
1198   // Compute total buffered bytes from the leaves of the nodes tree to the root.
1199   for (const auto& node :
1200        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1201     tf_shared_lock l(node->mu_);
1202     node->TotalBufferedBytesHelper(&total_bytes);
1203   }
1204   TotalBufferedBytesHelper(&total_bytes);
1205 
1206   return total_bytes[long_name()];
1207 }
1208 
TotalMaximumBufferedBytes() const1209 double Node::TotalMaximumBufferedBytes() const {
1210   absl::flat_hash_map<string, double> total_bytes;
1211   tf_shared_lock l(mu_);
1212   // Compute total maximum buffered bytes from the leaves of the nodes tree
1213   // to the root.
1214   for (const auto& node :
1215        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1216     tf_shared_lock l(node->mu_);
1217     node->TotalMaximumBufferedBytesHelper(&total_bytes);
1218   }
1219   TotalMaximumBufferedBytesHelper(&total_bytes);
1220 
1221   return total_bytes[long_name()];
1222 }
1223 
TotalProcessingTime(absl::flat_hash_map<string,double> * processing_times)1224 double Node::TotalProcessingTime(
1225     absl::flat_hash_map<string, double>* processing_times) {
1226   // Create a hash map to store the per-element CPU time spent in the subtree
1227   // rooted in each node.
1228   absl::flat_hash_map<string, double> total_processing_times;
1229   tf_shared_lock l(mu_);
1230 
1231   // Computes per-element CPU time spent in the subtree rooted in the node from
1232   // the leaves of the nodes tree to the root.
1233   for (const auto& node :
1234        CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1235     tf_shared_lock l(node->mu_);
1236     node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1237   }
1238   TotalProcessingTimeLocked(processing_times, &total_processing_times);
1239 
1240   return total_processing_times[long_name()];
1241 }
1242 
AverageBufferedElementSize() const1243 double Node::AverageBufferedElementSize() const {
1244   DCHECK_GE(num_elements_, 0);
1245   DCHECK_GE(buffered_elements_, 0);
1246   if (num_elements_ <= 0) {
1247     if (buffered_elements_ <= 0) {
1248       // If there are no produced elements or buffered elements recorded, return
1249       // 0.
1250       return 0;
1251     }
1252     // If there are no produced elements but some buffered elements, return the
1253     // average size of all buffered elements.
1254     return static_cast<double>(buffered_bytes_) /
1255            static_cast<double>(buffered_elements_);
1256   }
1257 
1258   if (buffered_elements_ <= 0) {
1259     // If there are no buffered elements but some produced elements, return the
1260     // average size of all produced elements.
1261     return static_cast<double>(bytes_produced_) /
1262            static_cast<double>(num_elements_);
1263   }
1264 
1265   // Otherwise, return the mean value of average size of all produced elements
1266   // and average size of all buffered elements.
1267   return (static_cast<double>(bytes_produced_) /
1268               static_cast<double>(num_elements_) +
1269           static_cast<double>(buffered_bytes_) /
1270               static_cast<double>(buffered_elements_)) /
1271          2.0;
1272 }
1273 
OutputTimeForInputs(const absl::flat_hash_map<string,double> & output_times) const1274 double Node::OutputTimeForInputs(
1275     const absl::flat_hash_map<string, double>& output_times) const {
1276   double sum = 0;
1277   for (auto& input : inputs_) {
1278     // Inputs for which autotuning is disabled are excluded.
1279     if (input->autotune()) {
1280       sum += output_times.at(input->long_name());
1281     }
1282   }
1283   return sum;
1284 }
1285 
OutputTimeGradientsForInputs(const absl::flat_hash_map<string,double> & output_time_gradients) const1286 double Node::OutputTimeGradientsForInputs(
1287     const absl::flat_hash_map<string, double>& output_time_gradients) const {
1288   double sum = 0;
1289   for (auto& input : inputs_) {
1290     // Inputs for which autotuning is disabled are excluded.
1291     if (input->autotune()) {
1292       sum +=
1293           gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1294     }
1295   }
1296   return sum;
1297 }
1298 
TotalProcessingTimeForInputs(const absl::flat_hash_map<string,double> & total_processing_times)1299 double Node::TotalProcessingTimeForInputs(
1300     const absl::flat_hash_map<string, double>& total_processing_times) {
1301   // If the number of elements produced by an input is smaller than this
1302   // constant, then its processing time is estimated using a weighted average
1303   // of the empirical processing time and processing time history.
1304   constexpr int kNumElementsThreshold = 30;
1305 
1306   // Identifies the minimum number of input processing times to collect
1307   // before the processing time history is used as a prior.
1308   constexpr int kCountThreshold = 30;
1309 
1310   double sum = 0;
1311   for (auto& input : inputs_) {
1312     // Inputs for which autotuning is disabled are excluded.
1313     if (input->autotune()) {
1314       double input_processing_time =
1315           total_processing_times.at(input->long_name());
1316       int64 num_elements = input->num_elements();
1317       if (num_elements < kNumElementsThreshold) {
1318         if (input_processing_time_count_ < kCountThreshold) {
1319           sum += input_processing_time;
1320         } else {
1321           // The fewer elements the input has produced so far, the more weight
1322           // is assigned to the prior to reduce volatility.
1323           double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
1324           double prior =
1325               input_processing_time_sum_ / input_processing_time_count_;
1326           sum += (1.0L - prior_weight) * input_processing_time +
1327                  prior_weight * prior;
1328         }
1329       } else {
1330         sum += input_processing_time;
1331         input_processing_time_count_++;
1332         input_processing_time_sum_ += input_processing_time;
1333       }
1334     }
1335   }
1336   return sum;
1337 }
1338 
SelfProcessingTimeLocked() const1339 double Node::SelfProcessingTimeLocked() const {
1340   if (num_elements_ == 0) {
1341     return 0;
1342   }
1343   return static_cast<double>(processing_time_) /
1344          static_cast<double>(num_elements_);
1345 }
1346 
CollectNodes(TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const1347 Node::NodeVector Node::CollectNodes(
1348     TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1349     TF_SHARED_LOCKS_REQUIRED(mu_) {
1350   NodeVector node_vector;
1351   std::list<std::shared_ptr<Node>> temp_list;
1352 
1353   for (auto& input : inputs_) {
1354     if (collect_node(input)) {
1355       node_vector.push_back(input);
1356       temp_list.push_back(input);
1357     }
1358   }
1359 
1360   while (!temp_list.empty()) {
1361     auto cur_node = temp_list.front();
1362     temp_list.pop_front();
1363     tf_shared_lock l(cur_node->mu_);
1364     for (auto& input : cur_node->inputs_) {
1365       if (collect_node(input)) {
1366         node_vector.push_back(input);
1367         temp_list.push_back(input);
1368       }
1369     }
1370   }
1371 
1372   if (order == TraversalOrder::REVERSE_BFS) {
1373     std::reverse(node_vector.begin(), node_vector.end());
1374   }
1375   return node_vector;
1376 }
1377 
CollectTunableParametersHelper(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters) const1378 void Node::CollectTunableParametersHelper(
1379     absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const
1380     TF_SHARED_LOCKS_REQUIRED(mu_) {
1381   // If autotune is turned off or there are no elements recorded, we don't
1382   // collect the parameters on the node.
1383   if (!autotune_ || num_elements_ <= 0) {
1384     return;
1385   }
1386   for (auto& pair : parameters_) {
1387     if (pair.second->state->tunable) {
1388       parameters->insert(std::make_pair(long_name(), pair.second));
1389     }
1390   }
1391 }
1392 
DebugStringHelper(absl::flat_hash_map<string,string> * debug_strings) const1393 void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
1394     const TF_SHARED_LOCKS_REQUIRED(mu_) {
1395   string result;
1396   strings::StrAppend(&result, long_name(), ":\n");
1397   strings::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
1398   strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(),
1399                      "\n");
1400   strings::StrAppend(&result, "  buffered_elements=", buffered_elements_.load(),
1401                      "\n");
1402   strings::StrAppend(&result, "  bytes_consumed=", bytes_consumed_.load(),
1403                      "\n");
1404   strings::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(),
1405                      "\n");
1406   strings::StrAppend(&result, "  processing_time=", processing_time_.load(),
1407                      "\n");
1408   strings::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
1409   string inputs;
1410   for (auto& input : inputs_) {
1411     strings::StrAppend(&inputs, input->long_name(), ",");
1412   }
1413   strings::StrAppend(&result, "  inputs={", inputs, "}\n");
1414   for (auto& input : inputs_) {
1415     strings::StrAppend(&result, debug_strings->at(input->long_name()));
1416   }
1417   debug_strings->insert(std::make_pair(long_name(), result));
1418 }
1419 
SnapshotHelper(std::shared_ptr<Node> cloned_output,Node::NodePairList * node_pairs) const1420 std::shared_ptr<Node> Node::SnapshotHelper(
1421     std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
1422   tf_shared_lock l(mu_);
1423 
1424   // Clone current node(`this`), also set clone of its output node
1425   // (`cloned_output`) to be the output node of the cloned node
1426   // (`cloned_current`).
1427   std::shared_ptr<Node> cloned_current = Clone(cloned_output);
1428   {
1429     cloned_current->autotune_.store(autotune_);
1430     cloned_current->buffered_bytes_.store(buffered_bytes_);
1431     cloned_current->buffered_elements_.store(buffered_elements_);
1432     cloned_current->bytes_consumed_.store(bytes_consumed_);
1433     cloned_current->bytes_produced_.store(bytes_produced_);
1434     cloned_current->num_elements_.store(num_elements_);
1435     cloned_current->record_metrics_.store(false);
1436     cloned_current->processing_time_.store(processing_time_);
1437     mutex_lock l2(cloned_current->mu_);
1438     cloned_current->parameters_ = parameters_;
1439   }
1440 
1441   for (auto& input : inputs_) {
1442     node_pairs->push_back(std::make_pair(input, cloned_current));
1443   }
1444   return cloned_current;
1445 }
1446 
TotalBufferedBytesHelper(absl::flat_hash_map<string,double> * total_bytes) const1447 void Node::TotalBufferedBytesHelper(
1448     absl::flat_hash_map<string, double>* total_bytes) const
1449     TF_SHARED_LOCKS_REQUIRED(mu_) {
1450   if (!autotune_) {
1451     total_bytes->insert(std::make_pair(long_name(), 0));
1452     return;
1453   }
1454 
1455   double result = 0;
1456   auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1457   if (!parameter) {
1458     parameter = gtl::FindOrNull(parameters_, kParallelism);
1459   }
1460   if (parameter) {
1461     result = buffered_bytes_;
1462   }
1463   for (auto& input : inputs_) {
1464     result += total_bytes->at(input->long_name());
1465   }
1466   total_bytes->insert(std::make_pair(long_name(), result));
1467 }
1468 
TotalMaximumBufferedBytesHelper(absl::flat_hash_map<string,double> * total_bytes) const1469 void Node::TotalMaximumBufferedBytesHelper(
1470     absl::flat_hash_map<string, double>* total_bytes) const
1471     TF_SHARED_LOCKS_REQUIRED(mu_) {
1472   if (!autotune_) {
1473     total_bytes->insert(std::make_pair(long_name(), 0));
1474     return;
1475   }
1476 
1477   double result = MaximumBufferedBytes();
1478   for (auto& input : inputs_) {
1479     result += total_bytes->at(input->long_name());
1480   }
1481   total_bytes->insert(std::make_pair(long_name(), result));
1482 }
1483 
MaximumBufferedBytes() const1484 double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1485   return 0;
1486 }
1487 
ToProto(ModelProto::Node * node_proto) const1488 Status Node::ToProto(ModelProto::Node* node_proto) const {
1489   tf_shared_lock l(mu_);
1490   node_proto->set_id(id_);
1491   node_proto->set_name(name_);
1492   node_proto->set_autotune(autotune_);
1493   node_proto->set_buffered_bytes(buffered_bytes_);
1494   node_proto->set_buffered_elements(buffered_elements_);
1495   node_proto->set_bytes_consumed(bytes_consumed_);
1496   node_proto->set_bytes_produced(bytes_produced_);
1497   node_proto->set_num_elements(num_elements_);
1498   node_proto->set_processing_time(processing_time_);
1499   node_proto->set_record_metrics(record_metrics_);
1500 
1501   // Produce protos for all parameters.
1502   for (auto const& parameter : parameters_) {
1503     ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
1504     parameter_proto->set_name(parameter.first);
1505     parameter_proto->set_value(parameter.second->value);
1506     parameter_proto->set_min(parameter.second->min);
1507     parameter_proto->set_max(parameter.second->max);
1508     parameter_proto->set_state_value(parameter.second->state->value);
1509     parameter_proto->set_tunable(parameter.second->state->tunable);
1510   }
1511 
1512   // Produce protos for all inputs.
1513   for (auto const& input : inputs_) {
1514     ModelProto::Node* input_proto = node_proto->add_inputs();
1515     TF_RETURN_IF_ERROR(input->ToProto(input_proto));
1516   }
1517   return Status::OK();
1518 }
1519 
FromProtoHelper(ModelProto::Node node_proto,std::shared_ptr<Node> node)1520 Status Node::FromProtoHelper(ModelProto::Node node_proto,
1521                              std::shared_ptr<Node> node) {
1522   tf_shared_lock l(node->mu_);
1523   node->autotune_.store(node_proto.autotune());
1524   node->buffered_bytes_.store(node_proto.buffered_bytes());
1525   node->buffered_elements_.store(node_proto.buffered_elements());
1526   node->bytes_consumed_.store(node_proto.bytes_consumed());
1527   node->bytes_produced_.store(node_proto.bytes_produced());
1528   node->num_elements_.store(node_proto.num_elements());
1529   node->processing_time_.store(node_proto.processing_time());
1530   node->record_metrics_.store(node_proto.record_metrics());
1531 
1532   // Restore parameters.
1533   int64 num_parameters = node_proto.parameters_size();
1534   for (int i = 0; i < num_parameters; i++) {
1535     const ModelProto::Node::Parameter& parameter_proto =
1536         node_proto.parameters(i);
1537     std::shared_ptr<SharedState> state;
1538     if (parameter_proto.tunable()) {
1539       state =
1540           std::make_shared<SharedState>(kAutotune, std::make_shared<mutex>(),
1541                                         std::make_shared<condition_variable>());
1542       state->value = parameter_proto.state_value();
1543     } else {
1544       state = std::make_shared<SharedState>(
1545           parameter_proto.state_value(), std::make_shared<mutex>(),
1546           std::make_shared<condition_variable>());
1547     }
1548     node->parameters_[parameter_proto.name()] =
1549         MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
1550                       parameter_proto.max());
1551   }
1552   return Status::OK();
1553 }
1554 
FromProto(ModelProto::Node node_proto,std::shared_ptr<Node> output,std::shared_ptr<Node> * node)1555 Status Node::FromProto(ModelProto::Node node_proto,
1556                        std::shared_ptr<Node> output,
1557                        std::shared_ptr<Node>* node) {
1558   // Note that parameters are restored in `FromProtoHelper`.
1559   Args args = {node_proto.id(), node_proto.name(), std::move(output)};
1560   std::shared_ptr<Node> restored_node;
1561   switch (node_proto.node_class()) {
1562     case NodeClass::INTERLEAVE_MANY:
1563       restored_node = std::make_shared<InterleaveMany>(args);
1564       break;
1565     case NodeClass::ASYNC_INTERLEAVE_MANY:
1566       restored_node = std::make_shared<AsyncInterleaveMany>(
1567           args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1568       break;
1569     case NodeClass::KNOWN_RATIO:
1570       restored_node = std::make_shared<KnownRatio>(args, node_proto.ratio());
1571       break;
1572     case NodeClass::ASYNC_KNOWN_RATIO:
1573       restored_node = std::make_shared<AsyncKnownRatio>(
1574           args, node_proto.ratio(), node_proto.memory_ratio(),
1575           /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1576       break;
1577     case NodeClass::UNKNOWN_RATIO:
1578       restored_node = std::make_shared<UnknownRatio>(args);
1579       break;
1580     default:
1581       restored_node = std::make_shared<Unknown>(args);
1582   }
1583   TF_RETURN_IF_ERROR(FromProtoHelper(node_proto, restored_node));
1584 
1585   // Restore all input nodes as well.
1586   int64 num_inputs = node_proto.inputs_size();
1587   for (int64 i = 0; i < num_inputs; i++) {
1588     const ModelProto::Node& input_proto = node_proto.inputs(i);
1589     std::shared_ptr<Node> input;
1590     TF_RETURN_IF_ERROR(FromProto(input_proto, restored_node, &input));
1591     restored_node->add_input(input);
1592   }
1593   (*node) = std::move(restored_node);
1594   return Status::OK();
1595 }
1596 
AddNode(Node::Factory factory,const string & name,std::shared_ptr<Node> parent,std::shared_ptr<Node> * out_node)1597 void Model::AddNode(Node::Factory factory, const string& name,
1598                     std::shared_ptr<Node> parent,
1599                     std::shared_ptr<Node>* out_node) {
1600   // The name captures the sequence of iterators joined by `::`. We only use the
1601   // last element of the sequence as the name node.
1602   auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
1603   mutex_lock l(mu_);
1604   std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
1605   if (!output_) {
1606     output_ = node;
1607   }
1608   if (parent) {
1609     VLOG(3) << "Adding " << node->long_name() << " as input for "
1610             << parent->long_name();
1611     parent->add_input(node);
1612   } else {
1613     VLOG(3) << "Adding " << node->long_name();
1614   }
1615   collect_resource_usage_ =
1616       collect_resource_usage_ || node->has_tunable_parameters();
1617   *out_node = std::move(node);
1618   // TODO(jsimsa): Reset the optimization period when a node is added so that
1619   // autotuning adapts to changes to the input pipeline faster. Initial attempt
1620   // to enable this functionality caused a regression (see b/179812091).
1621 }
1622 
FlushMetrics()1623 void Model::FlushMetrics() {
1624   std::deque<std::shared_ptr<Node>> queue;
1625   {
1626     tf_shared_lock l(mu_);
1627     if (output_) queue.push_back(output_);
1628   }
1629   while (!queue.empty()) {
1630     auto node = queue.front();
1631     queue.pop_front();
1632     node->FlushMetrics();
1633     for (auto input : node->inputs()) {
1634       queue.push_back(input);
1635     }
1636   }
1637 }
1638 
Optimize(AutotuneAlgorithm algorithm,int64 cpu_budget,int64 ram_budget,double model_input_time)1639 void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
1640                      int64 ram_budget, double model_input_time) {
1641   std::shared_ptr<Node> snapshot;
1642   {
1643     tf_shared_lock lock(mu_);
1644     snapshot = output_->Snapshot();
1645   }
1646   OptimizationParams optimization_params;
1647   optimization_params.set_algorithm(algorithm);
1648   optimization_params.set_cpu_budget(cpu_budget);
1649   optimization_params.set_ram_budget(ram_budget);
1650   optimization_params.set_model_input_time(model_input_time);
1651   switch (algorithm) {
1652     case AutotuneAlgorithm::HILL_CLIMB:
1653       OptimizeHillClimb(snapshot, optimization_params);
1654       break;
1655     case AutotuneAlgorithm::GRADIENT_DESCENT:
1656       OptimizeGradientDescent(snapshot, optimization_params);
1657       break;
1658     default:
1659       VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
1660                  "optimization.";
1661       return;
1662   }
1663   if (!save_dir_.empty()) {
1664     mutex_lock lock(mu_);
1665     Status status = EnsureSaveLoopThreadStarted();
1666     if (status.ok() && save_buffer_.size() < kMaxNumBufferedOptimizeArgs) {
1667       save_buffer_.push_back(std::make_pair(snapshot, optimization_params));
1668       save_cond_var_.notify_all();
1669     } else if (save_buffer_.size() >= kMaxNumBufferedOptimizeArgs) {
1670       VLOG(3) << "Saved snapshots buffer is full. Current snapshot and "
1671                  "optimization parameters will not be saved.";
1672     }
1673   }
1674 }
1675 
RemoveNode(std::shared_ptr<Node> node)1676 void Model::RemoveNode(std::shared_ptr<Node> node) {
1677   mutex_lock l(mu_);
1678   if (node) {
1679     if (node->output()) {
1680       node->output()->remove_input(node);
1681     }
1682     VLOG(3) << "Removing " << node->long_name();
1683   }
1684 }
1685 
1686 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
CollectTunableParameters(std::shared_ptr<Node> node)1687 Model::CollectTunableParameters(std::shared_ptr<Node> node) {
1688   absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters;
1689   node->CollectTunableParameters(&parameters);
1690   return parameters;
1691 }
1692 
ShouldStop(int64 cpu_budget,int64 ram_budget,const absl::flat_hash_map<string,std::shared_ptr<Parameter>> & parameters,const absl::flat_hash_map<string,std::shared_ptr<Parameter>> & parallelism_parameters,const absl::flat_hash_map<string,std::shared_ptr<Parameter>> & buffer_size_parameters,std::shared_ptr<Node> snapshot,bool * cpu_budget_reached)1693 bool Model::ShouldStop(
1694     int64 cpu_budget, int64 ram_budget,
1695     const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
1696     const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
1697         parallelism_parameters,
1698     const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
1699         buffer_size_parameters,
1700     std::shared_ptr<Node> snapshot, bool* cpu_budget_reached) {
1701   if (!(*cpu_budget_reached)) {
1702     // If those essential transformations' parallelism reaches the CPU
1703     // budget, we will only tune the buffer size parameters in future
1704     // iterations.
1705     int64 model_parallelism = 0;
1706     for (auto& pair : parallelism_parameters) {
1707       model_parallelism += std::round(pair.second->value);
1708     }
1709     *cpu_budget_reached = (model_parallelism > cpu_budget);
1710   }
1711 
1712   bool all_max = true;
1713   for (auto& pair :
1714        (*cpu_budget_reached ? buffer_size_parameters : parameters)) {
1715     if (std::round(pair.second->value) < pair.second->max) {
1716       all_max = false;
1717       break;
1718     }
1719   }
1720 
1721   // If all parameters have reached their maximum values or RAM budget is
1722   // reached, we stop the iterations.
1723   return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
1724 }
1725 
1726 // TODO(jsimsa): Add support for tracking and using the model input time.
OptimizeLoop(AutotuneAlgorithm algorithm,int64 cpu_budget,int64 ram_budget,CancellationManager * cancellation_manager)1727 Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
1728                            int64 ram_budget,
1729                            CancellationManager* cancellation_manager) {
1730   std::function<void()> unused;
1731   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
1732       cancellation_manager,
1733       [this]() {
1734         mutex_lock l(mu_);
1735         optimize_cond_var_.notify_all();
1736       },
1737       /*deregister_fn=*/&unused));
1738 
1739   int64 last_optimization_ms = 0;
1740   int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1741   while (true) {
1742     {
1743       mutex_lock l(mu_);
1744       while (!cancellation_manager->IsCancelled() &&
1745              last_optimization_ms + optimization_period_ms_ > current_time_ms) {
1746         auto wait_ms =
1747             last_optimization_ms + optimization_period_ms_ - current_time_ms;
1748         VLOG(2) << "Waiting for " << wait_ms << " ms.";
1749         optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
1750         current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1751       }
1752       if (cancellation_manager->IsCancelled()) {
1753         return Status::OK();
1754       }
1755     }
1756 
1757     int64 optimization_start_us = EnvTime::NowMicros();
1758     Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0);
1759     VLOG(2) << "Optimized for "
1760             << (EnvTime::NowMicros() - optimization_start_us) << " us.";
1761 
1762     // Exponentially increase the period of running the optimization
1763     // until a threshold is reached.
1764     {
1765       mutex_lock l(mu_);
1766       optimization_period_ms_ =
1767           std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
1768     }
1769     current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1770     last_optimization_ms = current_time_ms;
1771     FlushMetrics();
1772   }
1773 }
1774 
OptimizeGradientDescent(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params)1775 void Model::OptimizeGradientDescent(
1776     std::shared_ptr<Node> snapshot,
1777     const OptimizationParams& optimization_params) {
1778   VLOG(2) << "Starting optimization of tunable parameters with Gradient "
1779              "Descent.";
1780   auto parameters = CollectTunableParameters(snapshot);
1781   if (parameters.empty()) {
1782     VLOG(2) << "The Gradient Descent optimization is terminated since no node "
1783                "with tunable parameters has recorded elements.";
1784     return;
1785   }
1786 
1787   // The maps of "essential" parallelism parameters and buffer size parameters.
1788   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
1789       parallelism_parameters, buffer_size_parameters;
1790   CollectParameters(snapshot, parameters, &parallelism_parameters,
1791                     &buffer_size_parameters);
1792 
1793   // Initialize the parameter values to minimal before tuning.
1794   for (auto& pair : parameters) {
1795     pair.second->value = pair.second->min;
1796   }
1797 
1798   // Optimization is stopped once the `OutputTime` improvement is smaller than
1799   // this value.
1800   constexpr double kOptimizationPrecision = 100.0L;
1801 
1802   // Maximum number of iterations for optimization.
1803   constexpr int64 kMaxIterations = 1000;
1804 
1805   double output_time = 0;
1806   double new_output_time;
1807 
1808   // When the CPU budget is reached, the parallelism parameter values are fixed
1809   // and we only increase the buffer size parameters.
1810   bool cpu_budget_reached = false;
1811 
1812   for (int i = 0; i < kMaxIterations &&
1813                   !ShouldStop(optimization_params.cpu_budget(),
1814                               optimization_params.ram_budget(), parameters,
1815                               parallelism_parameters, buffer_size_parameters,
1816                               snapshot, &cpu_budget_reached);
1817        ++i) {
1818     absl::flat_hash_map<string, double> gradients;
1819     new_output_time = OutputTime(
1820         snapshot, optimization_params.model_input_time(), &gradients);
1821     // We also terminate once the improvement of the output latency is too
1822     // small.
1823     if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
1824       break;
1825     }
1826 
1827     UpdateParameterValues(
1828         gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
1829     output_time = new_output_time;
1830   }
1831 
1832   for (auto& pair : parameters) {
1833     pair.second->value = std::round(pair.second->value);
1834   }
1835   UpdateStateValues(&parameters);
1836 }
1837 
OptimizeHillClimb(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params)1838 void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
1839                               const OptimizationParams& optimization_params) {
1840   VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
1841   const double processing_time = TotalProcessingTime(snapshot);
1842   auto parameters = CollectTunableParameters(snapshot);
1843   if (parameters.empty()) {
1844     VLOG(2) << "The Hill Climb optimization is terminated since no node with "
1845                "tunable parameters has recorded elements.";
1846     return;
1847   }
1848 
1849   // Buffer size parameter will only be incremented if the output latency
1850   // improvement is greater than this constant.
1851   constexpr double kBufferSizeMinDelta = 1.0L;
1852 
1853   // Initialize the parameter values to minimal before tuning.
1854   for (auto& pair : parameters) {
1855     pair.second->value = pair.second->min;
1856   }
1857   while (true) {
1858     const double output_time =
1859         OutputTime(snapshot, optimization_params.model_input_time(),
1860                    /*gradients=*/nullptr);
1861     bool all_max = true;
1862     for (auto& pair : parameters) {
1863       if (pair.second->value < pair.second->max) {
1864         all_max = false;
1865         break;
1866       }
1867     }
1868     if (output_time < processing_time / optimization_params.cpu_budget() ||
1869         all_max ||
1870         TotalMaximumBufferedBytes(snapshot) >
1871             optimization_params.ram_budget()) {
1872       break;
1873     }
1874     double best_delta = -1.0L;
1875     Parameter* best_parameter = nullptr;
1876     for (auto& pair : parameters) {
1877       if (pair.second->value >= pair.second->max) {
1878         continue;
1879       }
1880       pair.second->value++;
1881       double new_output_time =
1882           OutputTime(snapshot, optimization_params.model_input_time(),
1883                      /*gradients=*/nullptr);
1884       double delta = output_time - new_output_time;
1885       if (delta > best_delta &&
1886           (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
1887         best_delta = delta;
1888         best_parameter = pair.second.get();
1889       }
1890       pair.second->value--;
1891     }
1892     if (!best_parameter) {
1893       VLOG(2) << "Failed to find a tunable parameter that would further "
1894                  "decrease the output time. This means that the autotuning "
1895                  "optimization got stuck in a local maximum. The optimization "
1896                  "attempt will terminate early.";
1897       break;
1898     }
1899     best_parameter->value++;
1900   }
1901   UpdateStateValues(&parameters);
1902 }
1903 
OutputTime(std::shared_ptr<Node> node,double model_input_time,absl::flat_hash_map<string,double> * gradients)1904 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
1905                          absl::flat_hash_map<string, double>* gradients) {
1906   // To store the input time for each node.
1907   absl::flat_hash_map<string, double> input_times = {
1908       {kModelInputTimeKey, model_input_time}};
1909 
1910   // TODO(jsimsa): Now that we are accounting for buffer size in wait time
1911   // computation, assuming that the input is infinitely fast will result in
1912   // inaccurate estimates of the output latency.
1913   //
1914   // We should compute the output latency as a fix-point of the following
1915   // equation: `output_time = node(OutputTime(input_times(1, output_time))`.
1916 
1917   return node->OutputTime(&input_times, gradients);
1918 }
1919 
TotalBufferedBytes(std::shared_ptr<Node> node)1920 double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
1921   return node->TotalBufferedBytes();
1922 }
1923 
TotalMaximumBufferedBytes(std::shared_ptr<Node> node)1924 double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
1925   return node->TotalMaximumBufferedBytes();
1926 }
1927 
TotalProcessingTime(std::shared_ptr<Node> node)1928 double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
1929   return node->TotalProcessingTime(/*processing_times=*/nullptr);
1930 }
1931 
ToProto(ModelProto * model_proto)1932 Status Model::ToProto(ModelProto* model_proto) {
1933   ModelProto::Node* output_proto = model_proto->mutable_output();
1934   tf_shared_lock lock(mu_);
1935   TF_RETURN_IF_ERROR(output_->ToProto(output_proto));
1936   model_proto->set_id_counter(id_counter_);
1937   model_proto->set_collect_resource_usage(collect_resource_usage_);
1938   return Status::OK();
1939 }
1940 
FromProto(ModelProto model_proto,std::unique_ptr<Model> * model)1941 Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
1942   std::unique_ptr<Model> restored_model = std::make_unique<Model>();
1943   std::shared_ptr<Node> output;
1944   TF_RETURN_IF_ERROR(
1945       Node::FromProto(model_proto.output(), /*output=*/nullptr, &output));
1946   mutex_lock lock(restored_model->mu_);
1947   restored_model->output_ = output;
1948   restored_model->id_counter_ = model_proto.id_counter();
1949   restored_model->collect_resource_usage_.store(
1950       model_proto.collect_resource_usage());
1951   *model = std::move(restored_model);
1952   return Status::OK();
1953 }
1954 
Save(const string & fname,std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params)1955 Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
1956                    const OptimizationParams& optimization_params) {
1957   ModelProto model_proto;
1958   std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
1959   {
1960     mutex_lock lock(model_snapshot->mu_);
1961     model_snapshot->output_ = std::move(snapshot);
1962     model_snapshot->id_counter_ = id_counter_;
1963     model_snapshot->collect_resource_usage_.store(collect_resource_usage_);
1964   }
1965   TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
1966   OptimizationParams* saved_optimization_params =
1967       model_proto.mutable_optimization_params();
1968   *saved_optimization_params = optimization_params;
1969   return WriteBinaryProto(Env::Default(), fname, model_proto);
1970 }
1971 
Load(const string & fname,std::unique_ptr<Model> * model,OptimizationParams * optimization_params)1972 Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
1973                    OptimizationParams* optimization_params) {
1974   ModelProto model_proto;
1975   TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fname, &model_proto));
1976   TF_RETURN_IF_ERROR(FromProto(model_proto, model));
1977   const OptimizationParams restored_optimization_params =
1978       model_proto.optimization_params();
1979   *optimization_params = restored_optimization_params;
1980   return Status::OK();
1981 }
1982 
EnsureSaveLoopThreadStarted()1983 Status Model::EnsureSaveLoopThreadStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1984   if (!save_thread_) {
1985     save_thread_ = absl::WrapUnique(
1986         Env::Default()->StartThread({}, "tf_data_model_save", [this]() {
1987           Status status = SaveLoop();
1988           if (!status.ok()) {
1989             VLOG(2) << "Model save loop failed: " << status.ToString();
1990           }
1991         }));
1992   }
1993   return Status::OK();
1994 }
1995 
SaveLoop()1996 Status Model::SaveLoop() {
1997   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(save_dir_));
1998   while (true) {
1999     std::pair<std::shared_ptr<Node>, OptimizationParams> to_save;
2000     {
2001       mutex_lock l(mu_);
2002       while (!save_thread_cancelled_ && save_buffer_.empty()) {
2003         save_cond_var_.wait(l);
2004       }
2005       if (save_thread_cancelled_) {
2006         return Status::OK();
2007       }
2008       to_save = save_buffer_.front();
2009       save_buffer_.pop_front();
2010     }
2011     string model_name =
2012         absl::StrCat("autotune_model_",
2013                      Hash64Combine(static_cast<uint64>(EnvTime::NowMicros()),
2014                                    reinterpret_cast<uint64>(this)));
2015     string fname = io::JoinPath(save_dir_, model_name);
2016     TF_RETURN_IF_ERROR(Save(fname, to_save.first, to_save.second));
2017     VLOG(2) << "Model was saved as " << fname;
2018   }
2019 }
2020 
2021 }  // namespace model
2022 }  // namespace data
2023 }  // namespace tensorflow
2024