/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/model.h"

#include <memory>

#include "absl/time/clock.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace data {
namespace model {

constexpr int64_t Model::kOptimizationPeriodMinMs;
constexpr int64_t Model::kOptimizationPeriodMaxMs;

namespace {

// Helper function for node traversal that doesn't skip any nodes.
inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }

// Helper function for node traversal that filters out nodes for which
// autotuning is disabled.
inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
  return node->autotune();
}

// Wrapper for the square function to reduce verbosity.
inline double Square(double x) { return x * x; }

// Collects "essential" parallelism parameters and buffer size parameters in the
// tree rooted in the given node. Which parallelism parameters are essential is
// determined by the relative processing time spent in the corresponding
// transformation. The collected parameters are returned via maps that map node
// names to their respective parameters.
inline void CollectParameters(std::shared_ptr<Node> node,
                              const Node::ModelParameters& parameters,
                              Node::ModelParameters* parallelism_parameters,
                              Node::ModelParameters* buffer_size_parameters) {
  // A parallelism parameter is considered to be essential if the corresponding
  // transformation's processing time is greater than the essential rate times
  // the average transformation self processing time.
  constexpr double kEssentialRate = 0.3L;

  Node::NodeValues processing_times;
  double processing_time = node->TotalProcessingTime(&processing_times);
  double uniform_share =
      processing_time / static_cast<double>(processing_times.size());
  for (auto& pair : parameters) {
    if (pair.second->name == kParallelism &&
        processing_times[pair.first] > kEssentialRate * uniform_share) {
      parallelism_parameters->push_back(pair);
    } else if (pair.second->name == kBufferSize) {
      buffer_size_parameters->push_back(pair);
    }
  }
}
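// Illustrative sketch (not from the original source): with `kEssentialRate`
// of 0.3, a pipeline whose total processing time is 400ms spread across 4
// recorded nodes has a uniform share of 100ms per node, so a `parallelism`
// parameter is collected as "essential" only if its node's recorded
// processing time exceeds 0.3 * 100ms = 30ms; `buffer_size` parameters are
// always collected regardless of processing time.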

// Applies one step of the gradient descent method and updates the parameter
// values. If a new value falls outside the feasible range, it is clamped to
// the interval between the minimum and maximum values.
inline void UpdateParameterValues(const Node::ParameterGradients& gradients,
                                  Node::ModelParameters* parameters) {
  // Gradient descent step size.
  constexpr double kDescentStep = 0.1L;
  double new_value;

  double max_abs_derivative = 1.0;
  for (auto& pair : *parameters) {
    if (std::round(pair.second->value) != pair.second->max) {
      auto* gradient = gtl::FindOrNull(
          gradients, std::make_pair(pair.first, pair.second->name));
      if (gradient) {
        max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
      }
    }
  }
  for (auto& pair : *parameters) {
    auto* gradient = gtl::FindOrNull(
        gradients, std::make_pair(pair.first, pair.second->name));
    if (gradient) {
      new_value =
          pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
      // Projection on a feasible interval.
      if (new_value > pair.second->max) {
        pair.second->value = pair.second->max;
      } else if (new_value < pair.second->min) {
        pair.second->value = pair.second->min;
      } else {
        pair.second->value = new_value;
      }
    }
  }
}
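// Worked example (illustrative, not from the original source): suppose a
// single tunable parameter currently has value 4.0 with range [1, 8] and its
// gradient is -20.0, and no other parameter has a larger absolute derivative.
// Then max_abs_derivative = 20.0 and the update is
//   new_value = 4.0 - 0.1 * (-20.0) / 20.0 = 4.1,
// which lies inside [1, 8] and is therefore kept without clamping.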

// Copies the parameter values (which are for optimization tuning) and updates
// the state values (which are for the input pipeline to follow).
inline void UpdateStateValues(Node::ModelParameters* parameters) {
  for (auto& pair : *parameters) {
    auto& parameter = pair.second;
    VLOG(2) << "Setting tunable parameter " << pair.first
            << ":: " << parameter->name << " to " << parameter->value;
    mutex_lock l(*parameter->state->mu);
    parameter->state->value = parameter->value;
    parameter->state->cond_var->notify_all();
  }
}

// Recursively produces protos for nodes in a subtree of `output` node and
// appends them to nodes of the given model.
Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
  model->set_output(output->id());
  std::list<std::shared_ptr<Node>> to_serialize = {output};
  auto& nodes = *model->mutable_nodes();
  while (!to_serialize.empty()) {
    const std::shared_ptr<Node> node = to_serialize.front();
    to_serialize.pop_front();
    TF_RETURN_IF_ERROR(node->ToProto(&(nodes[node->id()])));
    for (auto input : node->inputs()) {
      to_serialize.push_back(input);
    }
  }
  return Status::OK();
}

// Recursively produces node tree rooted in `output` from the given model proto.
Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
  TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
                                     /*output=*/nullptr, output));
  std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
  while (!to_restore_inputs.empty()) {
    std::shared_ptr<Node> node = to_restore_inputs.front();
    to_restore_inputs.pop_front();
    for (int64_t input_id : model.nodes().at(node->id()).inputs()) {
      std::shared_ptr<Node> input;
      TF_RETURN_IF_ERROR(
          Node::FromProto(model.nodes().at(input_id), node, &input));
      node->add_input(input);
      to_restore_inputs.push_back(input);
    }
  }
  return Status::OK();
}

// The first input of InterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class InterleaveMany : public Node {
 public:
  using Node::Node;

  virtual ~InterleaveMany() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<InterleaveMany>(
        Args{id_, name_, std::move(output)});
  }

  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for the InterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (excluding the first input) to return
    // an element. Regardless of the `block_length` parameter of the
    // InterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }
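  // Illustrative sketch (not from the original source): with three inputs in
  // total, i.e. two interleaved input nodes besides the first input, an
  // inherited input time of 10ms and a self processing time of 2ms, each of
  // the two interleaved inputs is called on average every
  // (10ms + 2ms) * 2 = 24ms, which is the value stored as this node's input
  // time.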

  // The output time is the sum of the self processing time and the average
  // output time of inputs comprising the interleave "cycle".
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    double inputs_output_time =
        (OutputTimeForInputs(*output_times) -
         (*output_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient /= static_cast<double>(num_inputs() - 1);
        }
      }

      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients) -
          (*output_time_gradients)[inputs_.front()->long_name()];

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
    }
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
    return Status::OK();
  }
};

// The first input of AsyncInterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class AsyncInterleaveMany : public Node {
 public:
  AsyncInterleaveMany(Node::Args args,
                      std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncInterleaveMany() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncInterleaveMany>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for the AsyncInterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (excluding the first input) to return
    // an element. Regardless of the `block_length` parameter of the
    // AsyncInterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of self processing time and expected wait time
  // from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the average output time of inputs comprising the
  // interleave "cycle" divided by `parallelism`, `consumer_time` is the
  // `input_time` specified through `input_times` divided by `num_inputs() - 1`,
  // and if the node has parallelism parameter, then `buffer_size` is derived
  // from `parallelism`.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());
    consumer_time = input_time / static_cast<double>(num_inputs() - 1);
    double parallelism = num_inputs() - 1;  // default to cycle length
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      parallelism = std::min(parallelism, (*parameter)->value);
    }
    double output_time_for_inputs =
        OutputTimeForInputs(*output_times) -
        (*output_times)[inputs_.front()->long_name()];
    producer_time = output_time_for_inputs /
                    static_cast<double>(num_inputs() - 1) / parallelism;

    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der +
          producer_time_der * inputs_time_der_sum / parallelism;

      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (producer_time_der /
                        static_cast<double>(num_inputs() - 1) / parallelism);
        }
      }

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
      // Add derivative w.r.t. own parallelism parameter.
      if (parameter && (*parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(), (*parameter)->name)] =
            buffer_size_der - producer_time_der * producer_time / parallelism;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time + wait_time;
    (*output_times)[long_name()] = output_time;
  }
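  // Illustrative sketch (not from the original source): with a cycle of two
  // interleaved inputs (num_inputs() == 3), an input time of 20ms for this
  // node and a tunable parallelism of 2, the consumer time passed to
  // `ComputeWaitTime` is 20ms / 2 = 10ms, and if the two interleaved inputs
  // have a combined output time of 30ms, the producer time is
  // 30ms / 2 / 2 = 7.5ms.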

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      result += (*parameter)->value * AverageBufferedElementSize();
    }
    return result;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
    return Status::OK();
  }
};

class KnownRatio : public Node {
 public:
  KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}

  virtual ~KnownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
                                        ratio_);
  }

  // The input time is the sum of inherited input time and self processing time,
  // divided by `ratio_`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (ratio_ == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
    (*input_times)[long_name()] = input_time;
  }
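  // Illustrative sketch (not from the original source): a Batch-like node with
  // `ratio_` of 32 needs 32 input elements per output element, so with an
  // inherited input time of 62ms and a self processing time of 2ms its inputs
  // are asked for an element every (62ms + 2ms) / 32 = 2ms on average.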

  // The output time is the sum of the self processing time and the product of
  // `ratio_` and the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (ratio_ == 0) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= ratio_;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    return Status::OK();
  }

 private:
  const double ratio_;
};

class AsyncKnownRatio : public Node {
 public:
  AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                  std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncKnownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncKnownRatio>(
        Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
  }

  // The input time is the sum of inherited input time and parallelism adjusted
  // self processing time, divided by `ratio_`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }

    if (ratio_ == 0.0) {
      (*input_times)[long_name()] =
          inherited_input_time + SelfProcessingTimeLocked() / parallelism;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
        ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of parallelism adjusted self processing time and
  // expected wait time from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the product of `ratio_` and the sum of output times of
  // inputs, `consumer_time` is the product of `ratio_` and the `input_time`
  // specified through `input_times` (since for each element stored in the
  // buffer, the inputs need to be called `ratio_` times), and if the node has
  // parallelism parameter, then `buffer_size` is derived from `parallelism`.
  //
  // Current implementation assumes that there is at most 1 parameter per node.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double parallelism = 1.0;
    double buffer_size = 0.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
      if (ratio_ == 0) {
        buffer_size = parallelism;
      } else {
        // Currently, MapAndBatch is the only transformation that creates
        // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
        // `parallelism` threads to apply the function to elements from the
        // input dataset, while one element in the buffer actually corresponds
        // to `ratio_` elements from the input dataset. So we adjust the
        // `buffer_size` by dividing it by `ratio_`.
        buffer_size = parallelism / ratio_;
      }
    } else if (buffer_size_parameter) {
      buffer_size = (*buffer_size_parameter)->value;
    }
    double self_processing_time = SelfProcessingTimeLocked();
    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());

    if (ratio_ == 0) {
      consumer_time = input_time;
      producer_time = 0.0L;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }

        double producer_time_der = 0.0L;
        double consumer_time_der = 0.0L;
        double buffer_size_der = 0.0L;
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    &producer_time_der, &consumer_time_der,
                                    &buffer_size_der);
        (*output_time_gradients)[long_name()] = consumer_time_der;
        if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
          (*gradients)[std::make_pair(long_name(),
                                      (*parallelism_parameter)->name)] =
              -(1.0L + consumer_time_der) * self_processing_time /
                  Square(parallelism) +
              buffer_size_der;
        } else if (buffer_size_parameter &&
                   (*buffer_size_parameter)->state->tunable) {
          (*gradients)[std::make_pair(
              long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
        }
      } else {
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    /*producer_time_derivative=*/nullptr,
                                    /*consumer_time_derivative=*/nullptr,
                                    /*buffer_size_derivative=*/nullptr);
      }
      output_time = self_processing_time / parallelism + wait_time;
      (*output_times)[long_name()] = output_time;
      return;
    }

    consumer_time = input_time * ratio_;
    producer_time = ratio_ * OutputTimeForInputs(*output_times);
    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der + producer_time_der * inputs_time_der_sum;

      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (ratio_ * producer_time_der);
        }
      }

      // Add derivative w.r.t. own parameter if it's tunable.
      if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(),
                                    (*parallelism_parameter)->name)] =
            buffer_size_der / ratio_ -
            (1.0L + consumer_time_der +
             producer_time_der * inputs_time_der_sum) *
                self_processing_time / Square(parallelism);
      } else if (buffer_size_parameter &&
                 (*buffer_size_parameter)->state->tunable) {
        (*gradients)[std::make_pair(
            long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time / parallelism + wait_time;
    (*output_times)[long_name()] = output_time;
  }
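  // Illustrative sketch (not from the original source): for a MapAndBatch-like
  // node with `ratio_` of 32 (the batch size) and a parallelism value of 8,
  // the effective buffer size passed to `ComputeWaitTime` is 8 / 32 = 0.25
  // buffered output elements, since 8 in-flight input elements make up only a
  // quarter of one batched output element.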

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (!parameter) {
      parameter = gtl::FindOrNull(parameters_, kParallelism);
    }

    if (parameter) {
      if (memory_ratio_ == 0) {
        result += (*parameter)->value * AverageBufferedElementSize();
      } else {
        // The estimate is currently not accurate for MapAndBatchDataset
        // because the maximum buffer size does not match the
        // `num_parallel_calls` parameter.
        result +=
            (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
      }
    }
    return result;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    node_proto->set_memory_ratio(memory_ratio_);
    return Status::OK();
  }

 private:
  // Identifies how many input elements need to be created to construct an
  // element for the dataset.
  //
  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset, and
  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
  const double ratio_;
  // For parallelism nodes, identifies how many parallel calls are introduced
  // by one buffered element. The value is defined so that the RAM budget
  // bound can be correctly estimated from the given num_parallel_calls (or
  // buffer_size) combined with the estimated average size of buffered
  // elements.
  const double memory_ratio_;
};

class UnknownRatio : public Node {
 public:
  using Node::Node;

  virtual ~UnknownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
  }

  // The input time is the sum of inherited input time and self processing time,
  // divided by the ratio estimate.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the product of
  // the ratio estimate and the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }
    // TODO(jsimsa): The current implementation assumes that the number of input
    // elements consumed per output is the same across all inputs.
    double ratio = static_cast<double>(inputs_.front()->num_elements()) /
                   static_cast<double>(num_elements_);
    if (gradients) {
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= ratio;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of the ratio estimate and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (inputs_.empty() || num_elements_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // TODO(jsimsa): The current implementation assumes that the number of input
    // elements consumed per output is the same across all inputs.
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    double inputs_processing_time =
        ratio * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
    return Status::OK();
  }
};

class Unknown : public Node {
 public:
  using Node::Node;

  virtual ~Unknown() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
  }

  // The input time is the inherited input time.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    (*input_times)[long_name()] = inherited_input_time;
  }

  // The output time is the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
    if (gradients) {
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
  }

  // The processing time is the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    if (processing_times) {
      (*processing_times)[long_name()] = SelfProcessingTimeLocked();
    }
    (*total_processing_times)[long_name()] =
        TotalProcessingTimeForInputs(*total_processing_times);
  }

  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::UNKNOWN);
    return Status::OK();
  }
};

}  // namespace

thread_local int64_t Node::work_start_;

std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max) {
  return std::make_shared<Parameter>(name, state, min, max);
}

std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args) {
  return std::make_shared<InterleaveMany>(std::move(args));
}

std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncInterleaveMany>(std::move(args),
                                               std::move(parameters));
}

std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
  return std::make_shared<KnownRatio>(std::move(args), ratio);
}

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio, double memory_ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
                                           std::move(parameters));
}

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
                                 /*memory_ratio=*/ratio, std::move(parameters));
}

std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
  return MakeKnownRatioNode(std::move(args), 0);
}

std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
  return std::make_shared<UnknownRatio>(std::move(args));
}

std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
  return std::make_shared<Unknown>(std::move(args));
}
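// Usage sketch (illustrative only; real call sites live in the tf.data dataset
// kernels, and the exact identifiers below are assumptions): a prefetch-like
// transformation with a tunable buffer size might register itself roughly as
//
//   auto node = model::MakeAsyncKnownRatioNode(
//       model::Node::Args{/*id=*/2, /*name=*/"Prefetch", /*output=*/root},
//       /*ratio=*/1,
//       {model::MakeParameter(model::kBufferSize, shared_state,
//                             /*min=*/1, /*max=*/16)});
//
// where `root` and `shared_state` are hypothetical stand-ins for the output
// node and the SharedState shared with the iterator.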

double Node::ComputeWaitTime(const double& producer_time,
                             const double& consumer_time,
                             const double& buffer_size,
                             double* producer_time_derivative,
                             double* consumer_time_derivative,
                             double* buffer_size_derivative) {
  // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
  // p=`p_buffer_empty`, T=`wait_time`, then we have:
  // if y = 0, then p = 0;
  // elif x = 0, then p = 1;
  // elif x = y, then p = 1 / (n+1);
  // else p = [1 - x/y] / [1 - power(x/y, n+1)].
  //
  // We also have T = p * y, and the derivatives of T w.r.t. x, y, n are:
  // dT/dx = dp/dx * y,
  // dT/dy = p + dp/dy * y,
  // dT/dn = dp/dn * y.
  // The remaining work is to compute dp/dx, dp/dy, dp/dn by considering the
  // different cases and substituting the values into the formulas above.

  // Case 1: if producer is infinitely fast. The buffer will always be full.
  // Wait time will always be 0.
  if (producer_time == 0) {
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = 0` since p=0 on the
      // line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy at
      // (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy, where
      // p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
      if (buffer_size == 0 || consumer_time == 0) {
        *producer_time_derivative = 1.0L;
      } else {
        *producer_time_derivative = 0.0L;
      }
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = 0.0L;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return 0.0L;
  }

  // Case 2: if consumer is infinitely fast. Wait time is always the time to
  // produce an output.
  if (consumer_time == 0) {
    if (producer_time_derivative) {
      *producer_time_derivative = 1.0L;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since p=1 on the
      // line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx at
      // (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx, where
      // p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
      if (buffer_size == 0) {
        *consumer_time_derivative = 0.0L;
      } else {
        *consumer_time_derivative = -1.0L;
      }
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return producer_time;
  }

  // Case 3: the consumer and the producer are equally fast. Expected wait time
  // decreases linearly with the size of the buffer.
  if (consumer_time == producer_time) {
    const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
    const double p_buffer_empty_der =
        -buffer_size / (2.0L * buffer_size + 2.0L);
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = p_buffer_empty`
      // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there. Actually
      // to compute dp/dy at (y,y), we need to consider
      // lim_{dy->0} [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1),
      // p(y,y+dy) = [1 - y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
      *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since
      // p=1/(n+1) on the line x=y doesn't imply dp/dx = 0 there. Actually to
      // compute dp/dx at (x,x), we need to consider
      // lim_{dx->0} [p(x+dx,x)-p(x,x)] / dx, where p(x,x)=1/(n+1),
      // p(x+dx,x) = [1 - (x+dx)/x] / [1 - power((x+dx)/x, n+1)].
      *consumer_time_derivative = p_buffer_empty_der;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
    }
    return p_buffer_empty * producer_time;
  }

  // Case 4: the consumer is slower than the producer and neither is infinitely
  // fast. Cases 4 and 5 follow the same formula; they are handled separately
  // for reasons of numerical computation.
  if (consumer_time > producer_time) {
    const double ratio = producer_time / consumer_time;
    const double ratio_pow = std::pow(ratio, buffer_size);
    const double p_buffer_empty =
        ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
    const double p_buffer_empty_der =
        (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
        ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
    if (producer_time_derivative) {
      *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                                std::log(ratio) * producer_time;
    }
    return p_buffer_empty * producer_time;
  }

  // Case 5: the producer is slower than the consumer and neither is infinitely
  // fast.
  const double ratio = consumer_time / producer_time;
  const double ratio_pow = std::pow(ratio, buffer_size);
  const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
  const double p_buffer_empty_der =
      ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
      Square(1.0L - ratio_pow * ratio);
  if (producer_time_derivative) {
    *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
  }
  if (consumer_time_derivative) {
    *consumer_time_derivative = p_buffer_empty_der;
  }
  if (buffer_size_derivative) {
    *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                              ratio_pow * ratio * std::log(ratio) *
                              producer_time;
  }
  return p_buffer_empty * producer_time;
}
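// Worked example (illustrative, not from the original source): with
// producer_time = 2ms, consumer_time = 4ms and buffer_size = 1, Case 4
// applies with ratio = 2/4 = 0.5 and ratio_pow = 0.5, so
//   p_buffer_empty = 0.5 * (1 - 0.5) / (1 - 0.5 * 0.5) = 1/3,
// and the expected wait time is p_buffer_empty * producer_time = 2/3 ms,
// matching p = [1 - x/y] / [1 - power(x/y, n+1)] = (1 - 2) / (1 - 4) = 1/3
// from the derivation above.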
1116 
CollectTunableParametersLocked() const1117 Node::ModelParameters Node::CollectTunableParametersLocked() const {
1118   Node::ModelParameters parameters;
1119   // Collect tunable parameters from the leaves of the nodes tree to the root.
1120   for (const auto& node :
1121        CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1122     tf_shared_lock l(node->mu_);
1123     node->CollectTunableParametersHelper(&parameters);
1124   }
1125   CollectTunableParametersHelper(&parameters);
1126   return parameters;
1127 }
1128 
CollectTunableParameters() const1129 Node::ModelParameters Node::CollectTunableParameters() const {
1130   tf_shared_lock l(mu_);
1131   return CollectTunableParametersLocked();
1132 }
1133 
DebugString() const1134 string Node::DebugString() const {
1135   absl::flat_hash_map<string, string> debug_strings;
1136   tf_shared_lock l(mu_);
1137   // Build up the debug string from the leaves of the nodes tree to the root.
1138   for (const auto& node :
1139        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1140     tf_shared_lock l(node->mu_);
1141     node->DebugStringHelper(&debug_strings);
1142   }
1143   DebugStringHelper(&debug_strings);
1144 
1145   return debug_strings[long_name()];
1146 }
1147 
FlushMetrics()1148 void Node::FlushMetrics() {
1149   if (!record_metrics_) {
1150     return;
1151   }
1152   metrics_.record_bytes_consumed(bytes_consumed_);
1153   metrics_.record_bytes_produced(bytes_produced_);
1154   metrics_.record_num_elements(num_elements_);
1155 }
1156 
OutputTime(Node::NodeValues * input_times,Node::ParameterGradients * gradients) const1157 double Node::OutputTime(Node::NodeValues* input_times,
1158                         Node::ParameterGradients* gradients) const {
1159   // To store the output time gradient w.r.t. input time (if `gradients` is not
1160   // `nullptr`) and the output time for each node.
1161   Node::NodeValues output_time_gradients, output_times;
1162   tf_shared_lock l(mu_);
1163   auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);
1164 
1165   // Computes and stores input time for each node from the root to leaves of the
1166   // nodes tree.
1167   InputTimeLocked(input_times);
1168   for (const auto& node : nodes) {
1169     tf_shared_lock l(node->mu_);
1170     node->InputTimeLocked(input_times);
1171   }
1172 
1173   std::reverse(nodes.begin(), nodes.end());
1174   // Computes and stores the output time and output time gradient w.r.t. input
1175   // time (if `gradients` is not `nullptr`) for each node from leaves of the
1176   // nodes tree to the root.
1177   for (const auto& node : nodes) {
1178     tf_shared_lock l(node->mu_);
1179     node->OutputTimeLocked(*input_times, gradients, &output_times,
1180                            &output_time_gradients);
1181   }
1182   OutputTimeLocked(*input_times, gradients, &output_times,
1183                    &output_time_gradients);
1184 
1185   return output_times[long_name()];
1186 }
1187 
Snapshot() const1188 std::shared_ptr<Node> Node::Snapshot() const {
1189   NodePairList node_pairs;
1190   auto result = SnapshotHelper(nullptr, &node_pairs);
1191 
1192   while (!node_pairs.empty()) {
1193     auto node_pair = node_pairs.front();
1194     node_pairs.pop_front();
1195     std::shared_ptr<Node> current = node_pair.first,
1196                           cloned_output = node_pair.second;
1197     cloned_output->add_input(
1198         current->SnapshotHelper(cloned_output, &node_pairs));
1199   }
1200   return result;
1201 }
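
// A minimal sketch of the work-list cloning pattern used by `Snapshot` above,
// shown on a simplified tree type. `TreeNode` and `CloneTree` are hypothetical
// and for illustration only; they assume <list>, <memory>, <utility>, and
// <vector> are available.
namespace snapshot_sketch {

struct TreeNode {
  int value = 0;
  std::vector<std::shared_ptr<TreeNode>> children;
};

inline std::shared_ptr<TreeNode> CloneTree(
    const std::shared_ptr<TreeNode>& root) {
  auto cloned_root = std::make_shared<TreeNode>();
  cloned_root->value = root->value;
  // Each work item pairs an original node with the clone of its parent,
  // mirroring the role of `NodePairList` in `Snapshot`.
  std::list<std::pair<std::shared_ptr<TreeNode>, std::shared_ptr<TreeNode>>>
      work;
  for (const auto& child : root->children) {
    work.push_back(std::make_pair(child, cloned_root));
  }
  while (!work.empty()) {
    auto pair = work.front();
    work.pop_front();
    auto clone = std::make_shared<TreeNode>();
    clone->value = pair.first->value;
    pair.second->children.push_back(clone);
    for (const auto& child : pair.first->children) {
      work.push_back(std::make_pair(child, clone));
    }
  }
  return cloned_root;
}

}  // namespace snapshot_sketch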
1202 
1203 double Node::SelfProcessingTime() const {
1204   tf_shared_lock l(mu_);
1205   return SelfProcessingTimeLocked();
1206 }
1207 
1208 double Node::TotalBufferedBytes() const {
1209   Node::NodeValues total_bytes;
1210   tf_shared_lock l(mu_);
1211   // Compute total buffered bytes from the leaves of the nodes tree to the root.
1212   for (const auto& node :
1213        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1214     tf_shared_lock l(node->mu_);
1215     node->TotalBufferedBytesHelper(&total_bytes);
1216   }
1217   TotalBufferedBytesHelper(&total_bytes);
1218 
1219   return total_bytes[long_name()];
1220 }
1221 
1222 double Node::TotalMaximumBufferedBytes() const {
1223   Node::NodeValues total_bytes;
1224   tf_shared_lock l(mu_);
1225   // Compute total maximum buffered bytes from the leaves of the nodes tree
1226   // to the root.
1227   for (const auto& node :
1228        CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1229     tf_shared_lock l(node->mu_);
1230     node->TotalMaximumBufferedBytesHelper(&total_bytes);
1231   }
1232   TotalMaximumBufferedBytesHelper(&total_bytes);
1233 
1234   return total_bytes[long_name()];
1235 }
1236 
1237 double Node::TotalProcessingTime(Node::NodeValues* processing_times) {
1238   // Create a hash map to store the per-element CPU time spent in the subtree
1239   // rooted in each node.
1240   Node::NodeValues total_processing_times;
1241   tf_shared_lock l(mu_);
1242 
1243   // Computes per-element CPU time spent in the subtree rooted in the node from
1244   // the leaves of the nodes tree to the root.
1245   for (const auto& node :
1246        CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1247     tf_shared_lock l(node->mu_);
1248     node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1249   }
1250   TotalProcessingTimeLocked(processing_times, &total_processing_times);
1251 
1252   return total_processing_times[long_name()];
1253 }
1254 
1255 double Node::AverageBufferedElementSize() const {
1256   DCHECK_GE(num_elements_, 0);
1257   DCHECK_GE(buffered_elements_, 0);
1258   if (num_elements_ <= 0) {
1259     if (buffered_elements_ <= 0) {
1260       // If there are no produced elements or buffered elements recorded, return
1261       // 0.
1262       return 0;
1263     }
1264     // If there are no produced elements but some buffered elements, return the
1265     // average size of all buffered elements.
1266     return static_cast<double>(buffered_bytes_) /
1267            static_cast<double>(buffered_elements_);
1268   }
1269 
1270   if (buffered_elements_ <= 0) {
1271     // If there are no buffered elements but some produced elements, return the
1272     // average size of all produced elements.
1273     return static_cast<double>(bytes_produced_) /
1274            static_cast<double>(num_elements_);
1275   }
1276 
1277   // Otherwise, return the mean of the average size of all produced elements
1278   // and the average size of all buffered elements.
1279   return (static_cast<double>(bytes_produced_) /
1280               static_cast<double>(num_elements_) +
1281           static_cast<double>(buffered_bytes_) /
1282               static_cast<double>(buffered_elements_)) /
1283          2.0;
1284 }
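
// A short numeric sketch of the averaging rule above, using hypothetical
// counter values: a node that has produced 4 elements totalling 400 bytes
// (average 100) and currently buffers 2 elements totalling 300 bytes
// (average 150) is estimated at (100 + 150) / 2 = 125 bytes per element.
inline double AverageBufferedElementSizeSketch(double bytes_produced,
                                               double num_elements,
                                               double buffered_bytes,
                                               double buffered_elements) {
  if (num_elements <= 0 && buffered_elements <= 0) return 0.0;
  if (num_elements <= 0) return buffered_bytes / buffered_elements;
  if (buffered_elements <= 0) return bytes_produced / num_elements;
  return (bytes_produced / num_elements +
          buffered_bytes / buffered_elements) /
         2.0;
}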
1285 
1286 double Node::OutputTimeForInputs(const Node::NodeValues& output_times) const {
1287   double sum = 0;
1288   for (auto& input : inputs_) {
1289     // Inputs for which autotuning is disabled are excluded.
1290     if (input->autotune()) {
1291       sum += output_times.at(input->long_name());
1292     }
1293   }
1294   return sum;
1295 }
1296 
1297 double Node::OutputTimeGradientsForInputs(
1298     const Node::NodeValues& output_time_gradients) const {
1299   double sum = 0;
1300   for (auto& input : inputs_) {
1301     // Inputs for which autotuning is disabled are excluded.
1302     if (input->autotune()) {
1303       sum +=
1304           gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1305     }
1306   }
1307   return sum;
1308 }
1309 
1310 double Node::TotalProcessingTimeForInputs(
1311     const Node::NodeValues& total_processing_times) {
1312   // If the number of elements produced by an input is smaller than this
1313   // constant, then its processing time is estimated using a weighted average
1314   // of the empirical processing time and processing time history.
1315   constexpr int kNumElementsThreshold = 30;
1316 
1317   // Identifies the minimum number of input processing times to collect
1318   // before the processing time history is used as a prior.
1319   constexpr int kCountThreshold = 30;
1320 
1321   double sum = 0;
1322   for (auto& input : inputs_) {
1323     // Inputs for which autotuning is disabled are excluded.
1324     if (input->autotune()) {
1325       double input_processing_time =
1326           total_processing_times.at(input->long_name());
1327       int64_t num_elements = input->num_elements();
1328       if (num_elements < kNumElementsThreshold) {
1329         if (input_processing_time_count_ < kCountThreshold) {
1330           sum += input_processing_time;
1331         } else {
1332           // The fewer elements the input has produced so far, the more weight
1333           // is assigned to the prior to reduce volatility.
1334           double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
1335           double prior =
1336               input_processing_time_sum_ / input_processing_time_count_;
1337           sum += (1.0L - prior_weight) * input_processing_time +
1338                  prior_weight * prior;
1339         }
1340       } else {
1341         sum += input_processing_time;
1342         input_processing_time_count_++;
1343         input_processing_time_sum_ += input_processing_time;
1344       }
1345     }
1346   }
1347   return sum;
1348 }
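
// A minimal sketch of the prior blending described above, with hypothetical
// arguments: the weight assigned to the historical prior decays exponentially
// with the number of elements the input has already produced, so new inputs
// lean on history while mature inputs lean on their empirical measurement.
inline double BlendWithPriorSketch(double empirical_time, double prior_time,
                                   int64_t num_elements_produced) {
  // With 0 elements produced the prior carries weight 1/2; each additional
  // element halves that weight.
  const double prior_weight =
      1.0 / static_cast<double>(int64_t{2} << num_elements_produced);
  return (1.0 - prior_weight) * empirical_time + prior_weight * prior_time;
}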
1349 
1350 double Node::SelfProcessingTimeLocked() const {
1351   if (num_elements_ == 0) {
1352     return 0;
1353   }
1354   return static_cast<double>(processing_time_) /
1355          static_cast<double>(num_elements_);
1356 }
1357 
1358 Node::NodeVector Node::CollectNodes(
1359     TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1360     TF_SHARED_LOCKS_REQUIRED(mu_) {
1361   NodeVector node_vector;
1362   std::list<std::shared_ptr<Node>> temp_list;
1363 
1364   for (auto& input : inputs_) {
1365     if (collect_node(input)) {
1366       node_vector.push_back(input);
1367       temp_list.push_back(input);
1368     }
1369   }
1370 
1371   while (!temp_list.empty()) {
1372     auto cur_node = temp_list.front();
1373     temp_list.pop_front();
1374     tf_shared_lock l(cur_node->mu_);
1375     for (auto& input : cur_node->inputs_) {
1376       if (collect_node(input)) {
1377         node_vector.push_back(input);
1378         temp_list.push_back(input);
1379       }
1380     }
1381   }
1382 
1383   if (order == TraversalOrder::REVERSE_BFS) {
1384     std::reverse(node_vector.begin(), node_vector.end());
1385   }
1386   return node_vector;
1387 }
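
// A minimal sketch of the filtered breadth-first traversal implemented by
// `CollectNodes` above, on a hypothetical `SketchNode` type: children that
// fail the predicate are neither collected nor expanded, and reversing the
// collected vector yields the REVERSE_BFS order. Assumes <algorithm>, <list>,
// <memory>, and <vector>.
namespace traversal_sketch {

struct SketchNode {
  bool autotune = true;
  std::vector<std::shared_ptr<SketchNode>> children;
};

inline std::vector<std::shared_ptr<SketchNode>> CollectBfs(
    const std::shared_ptr<SketchNode>& root,
    bool (*keep)(const std::shared_ptr<SketchNode>&), bool reverse) {
  std::vector<std::shared_ptr<SketchNode>> collected;
  std::list<std::shared_ptr<SketchNode>> frontier;
  for (const auto& child : root->children) {
    if (keep(child)) {
      collected.push_back(child);
      frontier.push_back(child);
    }
  }
  while (!frontier.empty()) {
    auto current = frontier.front();
    frontier.pop_front();
    for (const auto& child : current->children) {
      if (keep(child)) {
        collected.push_back(child);
        frontier.push_back(child);
      }
    }
  }
  if (reverse) {
    std::reverse(collected.begin(), collected.end());
  }
  return collected;
}

}  // namespace traversal_sketch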
1388 
1389 void Node::CollectTunableParametersHelper(
1390     Node::ModelParameters* parameters) const TF_SHARED_LOCKS_REQUIRED(mu_) {
1391   // If autotune is turned off or there are no elements recorded, we don't
1392   // collect the parameters on the node.
1393   if (!autotune_ || num_elements_ <= 0) {
1394     return;
1395   }
1396   for (auto& pair : parameters_) {
1397     if (pair.second->state->tunable) {
1398       parameters->push_back(std::make_pair(long_name(), pair.second));
1399     }
1400   }
1401 }
1402 
1403 void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
1404     const TF_SHARED_LOCKS_REQUIRED(mu_) {
1405   string result;
1406   strings::StrAppend(&result, long_name(), ":\n");
1407   strings::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
1408   strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(),
1409                      "\n");
1410   strings::StrAppend(&result, "  buffered_elements=", buffered_elements_.load(),
1411                      "\n");
1412   strings::StrAppend(&result, "  bytes_consumed=", bytes_consumed_.load(),
1413                      "\n");
1414   strings::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(),
1415                      "\n");
1416   strings::StrAppend(&result, "  processing_time=", processing_time_.load(),
1417                      "\n");
1418   strings::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
1419   string inputs;
1420   for (auto& input : inputs_) {
1421     strings::StrAppend(&inputs, input->long_name(), ",");
1422   }
1423   strings::StrAppend(&result, "  inputs={", inputs, "}\n");
1424   for (auto& input : inputs_) {
1425     strings::StrAppend(&result, debug_strings->at(input->long_name()));
1426   }
1427   debug_strings->insert(std::make_pair(long_name(), result));
1428 }
1429 
1430 std::shared_ptr<Node> Node::SnapshotHelper(
1431     std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
1432   tf_shared_lock l(mu_);
1433 
1434   // Clone the current node (`this`) and set the clone of its output node
1435   // (`cloned_output`) to be the output node of the cloned node
1436   // (`cloned_current`).
1437   std::shared_ptr<Node> cloned_current = Clone(cloned_output);
1438   {
1439     cloned_current->autotune_.store(autotune_);
1440     cloned_current->buffered_bytes_.store(buffered_bytes_);
1441     cloned_current->buffered_elements_.store(buffered_elements_);
1442     cloned_current->bytes_consumed_.store(bytes_consumed_);
1443     cloned_current->bytes_produced_.store(bytes_produced_);
1444     cloned_current->num_elements_.store(num_elements_);
1445     cloned_current->record_metrics_.store(false);
1446     cloned_current->processing_time_.store(processing_time_);
1447     mutex_lock l2(cloned_current->mu_);
1448     cloned_current->parameters_ = parameters_;
1449   }
1450 
1451   for (auto& input : inputs_) {
1452     node_pairs->push_back(std::make_pair(input, cloned_current));
1453   }
1454   return cloned_current;
1455 }
1456 
1457 void Node::TotalBufferedBytesHelper(Node::NodeValues* total_bytes) const
1458     TF_SHARED_LOCKS_REQUIRED(mu_) {
1459   if (!autotune_) {
1460     total_bytes->insert(std::make_pair(long_name(), 0));
1461     return;
1462   }
1463 
1464   double result = 0;
1465   auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1466   if (!parameter) {
1467     parameter = gtl::FindOrNull(parameters_, kParallelism);
1468   }
1469   if (parameter) {
1470     result = buffered_bytes_;
1471   }
1472   for (auto& input : inputs_) {
1473     result += total_bytes->at(input->long_name());
1474   }
1475   total_bytes->insert(std::make_pair(long_name(), result));
1476 }
1477 
1478 void Node::TotalMaximumBufferedBytesHelper(Node::NodeValues* total_bytes) const
1479     TF_SHARED_LOCKS_REQUIRED(mu_) {
1480   if (!autotune_) {
1481     total_bytes->insert(std::make_pair(long_name(), 0));
1482     return;
1483   }
1484 
1485   double result = MaximumBufferedBytes();
1486   for (auto& input : inputs_) {
1487     result += total_bytes->at(input->long_name());
1488   }
1489   total_bytes->insert(std::make_pair(long_name(), result));
1490 }
1491 
1492 double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1493   return 0;
1494 }
1495 
1496 Status Node::ToProto(ModelProto::Node* node_proto) const {
1497   tf_shared_lock l(mu_);
1498   node_proto->set_id(id_);
1499   node_proto->set_name(name_);
1500   node_proto->set_autotune(autotune_);
1501   node_proto->set_buffered_bytes(buffered_bytes_);
1502   node_proto->set_buffered_elements(buffered_elements_);
1503   node_proto->set_bytes_consumed(bytes_consumed_);
1504   node_proto->set_bytes_produced(bytes_produced_);
1505   node_proto->set_num_elements(num_elements_);
1506   node_proto->set_processing_time(processing_time_);
1507   node_proto->set_record_metrics(record_metrics_);
1508 
1509   // Produce protos for all parameters.
1510   for (auto const& parameter : parameters_) {
1511     ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
1512     parameter_proto->set_name(parameter.first);
1513     parameter_proto->set_value(parameter.second->value);
1514     parameter_proto->set_min(parameter.second->min);
1515     parameter_proto->set_max(parameter.second->max);
1516     parameter_proto->set_state_value(parameter.second->state->value);
1517     parameter_proto->set_tunable(parameter.second->state->tunable);
1518   }
1519 
1520   // Add input node ids.
1521   for (auto const& input : inputs_) {
1522     node_proto->add_inputs(input->id());
1523   }
1524   return Status::OK();
1525 }
1526 
1527 Status Node::FromProtoHelper(ModelProto::Node node_proto,
1528                              std::shared_ptr<Node> node) {
1529   tf_shared_lock l(node->mu_);
1530   node->autotune_.store(node_proto.autotune());
1531   node->buffered_bytes_.store(node_proto.buffered_bytes());
1532   node->buffered_elements_.store(node_proto.buffered_elements());
1533   node->bytes_consumed_.store(node_proto.bytes_consumed());
1534   node->bytes_produced_.store(node_proto.bytes_produced());
1535   node->num_elements_.store(node_proto.num_elements());
1536   node->processing_time_.store(node_proto.processing_time());
1537   node->record_metrics_.store(node_proto.record_metrics());
1538 
1539   // Restore parameters.
1540   int64_t num_parameters = node_proto.parameters_size();
1541   for (int i = 0; i < num_parameters; i++) {
1542     const ModelProto::Node::Parameter& parameter_proto =
1543         node_proto.parameters(i);
1544     std::shared_ptr<SharedState> state;
1545     if (parameter_proto.tunable()) {
1546       state =
1547           std::make_shared<SharedState>(kAutotune, std::make_shared<mutex>(),
1548                                         std::make_shared<condition_variable>());
1549       state->value = parameter_proto.state_value();
1550     } else {
1551       state = std::make_shared<SharedState>(
1552           parameter_proto.state_value(), std::make_shared<mutex>(),
1553           std::make_shared<condition_variable>());
1554     }
1555     node->parameters_[parameter_proto.name()] =
1556         MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
1557                       parameter_proto.max());
1558   }
1559   return Status::OK();
1560 }
1561 
1562 Status Node::FromProto(ModelProto::Node node_proto,
1563                        std::shared_ptr<Node> output,
1564                        std::shared_ptr<Node>* node) {
1565   // Note that parameters are restored in `FromProtoHelper`.
1566   Args args = {node_proto.id(), node_proto.name(), std::move(output)};
1567   switch (node_proto.node_class()) {
1568     case NodeClass::INTERLEAVE_MANY:
1569       *node = std::make_shared<InterleaveMany>(args);
1570       break;
1571     case NodeClass::ASYNC_INTERLEAVE_MANY:
1572       *node = std::make_shared<AsyncInterleaveMany>(
1573           args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1574       break;
1575     case NodeClass::KNOWN_RATIO:
1576       *node = std::make_shared<KnownRatio>(args, node_proto.ratio());
1577       break;
1578     case NodeClass::ASYNC_KNOWN_RATIO:
1579       *node = std::make_shared<AsyncKnownRatio>(
1580           args, node_proto.ratio(), node_proto.memory_ratio(),
1581           /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1582       break;
1583     case NodeClass::UNKNOWN_RATIO:
1584       *node = std::make_shared<UnknownRatio>(args);
1585       break;
1586     default:
1587       *node = std::make_shared<Unknown>(args);
1588   }
1589   return FromProtoHelper(node_proto, *node);
1590 }
1591 
1592 Model::Model()
1593     : collect_resource_usage_(false),
1594       optimization_period_ms_(kOptimizationPeriodMinMs) {
1595   model_gauge_cell_ = metrics::GetTFDataModelGauge(
1596       strings::StrCat(reinterpret_cast<uint64>(this)));
1597   model_gauge_cell_->Set([&]() { return DebugString(); });
1598 }
1599 
1600 Model::~Model() {
1601   // Before the model is destroyed, we record its final state in the gauge.
1602   auto result = DebugString();
1603   model_gauge_cell_->Set([result]() { return result; });
1604 }
1605 
1606 void Model::AddNode(Node::Factory factory, const string& name,
1607                     std::shared_ptr<Node> parent,
1608                     std::shared_ptr<Node>* out_node) {
1609   // The name captures the sequence of iterators joined by `::`. We only use the
1610   // last element of the sequence as the node name.
1611   auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
1612   mutex_lock l(mu_);
1613   std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
1614   if (!output_) {
1615     output_ = node;
1616   }
1617   if (parent) {
1618     VLOG(3) << "Adding " << node->long_name() << " as input for "
1619             << parent->long_name();
1620     parent->add_input(node);
1621   } else {
1622     VLOG(3) << "Adding " << node->long_name();
1623   }
1624   collect_resource_usage_ =
1625       collect_resource_usage_ || node->has_tunable_parameters();
1626   *out_node = std::move(node);
1627   // TODO(jsimsa): Reset the optimization period when a node is added so that
1628   // autotuning adapts to changes to the input pipeline faster. Initial attempt
1629   // to enable this functionality caused a regression (see b/179812091).
1630 }
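
// A small sketch of the name handling above, using only the standard library
// in place of `str_util::Split` (illustrative helper, not used elsewhere): for
// an iterator prefix such as "Iterator::Prefetch::Map", only the final
// component "Map" becomes the node name.
inline std::string LastNameComponentSketch(const std::string& full_name) {
  const std::string::size_type pos = full_name.rfind(':');
  return pos == std::string::npos ? full_name : full_name.substr(pos + 1);
}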
1631 
1632 void Model::FlushMetrics() {
1633   std::deque<std::shared_ptr<Node>> queue;
1634   {
1635     tf_shared_lock l(mu_);
1636     if (output_) queue.push_back(output_);
1637   }
1638   while (!queue.empty()) {
1639     auto node = queue.front();
1640     queue.pop_front();
1641     node->FlushMetrics();
1642     for (auto input : node->inputs()) {
1643       queue.push_back(input);
1644     }
1645   }
1646 }
1647 
1648 void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
1649                      int64_t ram_budget, double model_input_time,
1650                      CancellationManager* cancellation_manager) {
1651   std::shared_ptr<Node> snapshot;
1652   {
1653     tf_shared_lock l(mu_);
1654     snapshot = output_->Snapshot();
1655   }
1656   OptimizationParams optimization_params;
1657   optimization_params.set_algorithm(algorithm);
1658   optimization_params.set_cpu_budget(cpu_budget);
1659   optimization_params.set_ram_budget(ram_budget);
1660   optimization_params.set_model_input_time(model_input_time);
1661   switch (algorithm) {
1662     case AutotuneAlgorithm::HILL_CLIMB:
1663       OptimizeHillClimb(snapshot, optimization_params, cancellation_manager);
1664       break;
1665     case AutotuneAlgorithm::GRADIENT_DESCENT:
1666       OptimizeGradientDescent(snapshot, optimization_params,
1667                               cancellation_manager);
1668       break;
1669     default:
1670       VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
1671                  "optimization.";
1672       return;
1673   }
1674 }
1675 
1676 void Model::RemoveNode(std::shared_ptr<Node> node) {
1677   mutex_lock l(mu_);
1678   if (node) {
1679     if (node->output()) {
1680       node->output()->remove_input(node);
1681     }
1682     VLOG(3) << "Removing " << node->long_name();
1683   }
1684 }
1685 
1686 Model::ModelParameters Model::CollectTunableParameters(
1687     std::shared_ptr<Node> node) {
1688   return node->CollectTunableParameters();
1689 }
1690 
1691 bool Model::ShouldStop(int64_t cpu_budget, int64_t ram_budget,
1692                        const Model::ModelParameters& parameters,
1693                        const Model::ModelParameters& parallelism_parameters,
1694                        const Model::ModelParameters& buffer_size_parameters,
1695                        std::shared_ptr<Node> snapshot,
1696                        bool* cpu_budget_reached) {
1697   if (!(*cpu_budget_reached)) {
1698     // If the combined parallelism of the essential transformations reaches
1699     // the CPU budget, we will only tune the buffer size parameters in future
1700     // iterations.
1701     int64_t model_parallelism = 0;
1702     for (auto& pair : parallelism_parameters) {
1703       model_parallelism += std::round(pair.second->value);
1704     }
1705     *cpu_budget_reached = (model_parallelism > cpu_budget);
1706   }
1707 
1708   bool all_max = true;
1709   for (auto& pair :
1710        (*cpu_budget_reached ? buffer_size_parameters : parameters)) {
1711     if (std::round(pair.second->value) < pair.second->max) {
1712       all_max = false;
1713       break;
1714     }
1715   }
1716 
1717   // If all parameters have reached their maximum values or the RAM budget is
1718   // reached, we stop the iterations.
1719   return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
1720 }
1721 
1722 // TODO(jsimsa): Add support for tracking and using the model input time.
1723 Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
1724                            int64_t ram_budget,
1725                            CancellationManager* cancellation_manager) {
1726   std::function<void()> unused;
1727   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
1728       cancellation_manager,
1729       [this]() {
1730         mutex_lock l(mu_);
1731         optimize_cond_var_.notify_all();
1732       },
1733       /*deregister_fn=*/&unused));
1734 
1735   int64_t last_optimization_ms = 0;
1736   int64_t current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1737   while (true) {
1738     {
1739       mutex_lock l(mu_);
1740       while (!cancellation_manager->IsCancelled() &&
1741              last_optimization_ms + optimization_period_ms_ > current_time_ms) {
1742         auto wait_ms =
1743             last_optimization_ms + optimization_period_ms_ - current_time_ms;
1744         VLOG(2) << "Waiting for " << wait_ms << " ms.";
1745         optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
1746         current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1747       }
1748       if (cancellation_manager->IsCancelled()) {
1749         return Status::OK();
1750       }
1751     }
1752 
1753     int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1754     Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0,
1755              cancellation_manager);
1756     int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1757     VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
1758 
1759     // Exponentially increase the period of running the optimization
1760     // until a threshold is reached.
1761     {
1762       mutex_lock l(mu_);
1763       optimization_period_ms_ =
1764           std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
1765     }
1766     current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
1767     last_optimization_ms = current_time_ms;
1768     FlushMetrics();
1769   }
1770 }
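
// A minimal sketch of the period growth used by `OptimizeLoop` above: the
// optimization period doubles after every round until it hits a cap, so early
// rounds react quickly while a warmed-up pipeline is optimized rarely. The cap
// below is hypothetical; the real loop uses `kOptimizationPeriodMaxMs`.
inline int64_t NextOptimizationPeriodMsSketch(int64_t current_period_ms) {
  constexpr int64_t kMaxPeriodMsSketch = 60 * 1000;
  return std::min(current_period_ms << 1, kMaxPeriodMsSketch);
}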
1771 
1772 void Model::OptimizeGradientDescent(
1773     std::shared_ptr<Node> snapshot,
1774     const OptimizationParams& optimization_params,
1775     CancellationManager* cancellation_manager) {
1776   VLOG(2) << "Starting optimization of tunable parameters with Gradient "
1777              "Descent.";
1778   auto parameters = CollectTunableParameters(snapshot);
1779   if (parameters.empty()) {
1780     VLOG(2) << "The Gradient Descent optimization is terminated since no node "
1781                "with tunable parameters has recorded elements.";
1782     return;
1783   }
1784   VLOG(2) << "Number of tunable parameters: " << parameters.size();
1785 
1786   // The vectors of "essential" parallelism parameters and buffer size
1787   // parameters.
1788   Model::ModelParameters parallelism_parameters, buffer_size_parameters;
1789   CollectParameters(snapshot, parameters, &parallelism_parameters,
1790                     &buffer_size_parameters);
1791 
1792   // Initialize the parameter values to minimal before tuning.
1793   for (auto& pair : parameters) {
1794     pair.second->value = pair.second->min;
1795   }
1796 
1797   // Optimization is stopped once the `OutputTime` improvement is smaller than
1798   // this value.
1799   constexpr double kOptimizationPrecision = 100.0L;
1800 
1801   // Maximum number of iterations for optimization.
1802   constexpr int64_t kMaxIterations = 1000;
1803 
1804   double output_time = 0;
1805   double new_output_time;
1806 
1807   // When the CPU budget is reached, the parallelism parameter values are fixed
1808   // and we only increase the buffer size parameters.
1809   bool cpu_budget_reached = false;
1810 
1811   for (int i = 0; i < kMaxIterations; ++i) {
1812     if (cancellation_manager->IsCancelled() ||
1813         ShouldStop(optimization_params.cpu_budget(),
1814                    optimization_params.ram_budget(), parameters,
1815                    parallelism_parameters, buffer_size_parameters, snapshot,
1816                    &cpu_budget_reached)) {
1817       break;
1818     }
1819     Model::ParameterGradients gradients;
1820     new_output_time = OutputTime(
1821         snapshot, optimization_params.model_input_time(), &gradients);
1822     // We also terminate once the improvement of the output latency is too
1823     // small.
1824     if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
1825       break;
1826     }
1827 
1828     UpdateParameterValues(
1829         gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
1830     output_time = new_output_time;
1831   }
1832 
1833   for (auto& pair : parameters) {
1834     pair.second->value = std::round(pair.second->value);
1835   }
1836   UpdateStateValues(&parameters);
1837 }
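
// A self-contained sketch of the stopping logic used above, applied to a
// one-dimensional objective (hypothetical functions, for illustration only):
// descend along the gradient until the objective improves by less than
// `precision` or the iteration budget runs out, mirroring the combination of
// `kOptimizationPrecision` and `kMaxIterations`.
inline double GradientDescentSketch(double (*objective)(double),
                                    double (*gradient)(double), double x,
                                    double learning_rate, double precision,
                                    int max_iterations) {
  double value = objective(x);
  for (int i = 0; i < max_iterations; ++i) {
    x -= learning_rate * gradient(x);
    const double new_value = objective(x);
    if (std::abs(value - new_value) < precision) {
      break;
    }
    value = new_value;
  }
  return x;
}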
1838 
1839 void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
1840                               const OptimizationParams& optimization_params,
1841                               CancellationManager* cancellation_manager) {
1842   VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
1843   const double processing_time = TotalProcessingTime(snapshot);
1844   auto parameters = CollectTunableParameters(snapshot);
1845   if (parameters.empty()) {
1846     VLOG(2) << "The Hill Climb optimization is terminated since no node with "
1847                "tunable parameters has recorded elements.";
1848     return;
1849   }
1850   VLOG(2) << "Number of tunable parameters: " << parameters.size();
1851 
1852   // A buffer size parameter will only be incremented if the output latency
1853   // improvement is greater than this constant.
1854   constexpr double kBufferSizeMinDelta = 1.0L;
1855 
1856   // Initialize the parameter values to minimal before tuning.
1857   for (auto& pair : parameters) {
1858     pair.second->value = pair.second->min;
1859   }
1860   while (!cancellation_manager->IsCancelled()) {
1861     const double output_time =
1862         OutputTime(snapshot, optimization_params.model_input_time(),
1863                    /*gradients=*/nullptr);
1864     bool all_max = true;
1865     for (auto& pair : parameters) {
1866       if (pair.second->value < pair.second->max) {
1867         all_max = false;
1868         break;
1869       }
1870     }
1871     if (output_time < processing_time / optimization_params.cpu_budget() ||
1872         all_max ||
1873         TotalMaximumBufferedBytes(snapshot) >
1874             optimization_params.ram_budget()) {
1875       break;
1876     }
1877     double best_delta = -1.0L;
1878     Parameter* best_parameter = nullptr;
1879     for (auto& pair : parameters) {
1880       if (pair.second->value >= pair.second->max) {
1881         continue;
1882       }
1883       pair.second->value++;
1884       double new_output_time =
1885           OutputTime(snapshot, optimization_params.model_input_time(),
1886                      /*gradients=*/nullptr);
1887       double delta = output_time - new_output_time;
1888       if (delta > best_delta &&
1889           (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
1890         best_delta = delta;
1891         best_parameter = pair.second.get();
1892       }
1893       pair.second->value--;
1894     }
1895     if (!best_parameter) {
1896       VLOG(2) << "Failed to find a tunable parameter that would further "
1897                  "decrease the output time. This means that the autotuning "
1898                  "optimization got stuck in a local maximum. The optimization "
1899                  "attempt will terminate early.";
1900       break;
1901     }
1902     best_parameter->value++;
1903   }
1904   UpdateStateValues(&parameters);
1905 }
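
// A minimal sketch of the greedy step used by `OptimizeHillClimb` above, on a
// hypothetical vector of integer knobs: every knob below its maximum is
// incremented tentatively, the increment yielding the best cost improvement is
// kept, and the search stops once no increment helps.
inline void HillClimbSketch(std::vector<int>* knobs,
                            const std::vector<int>& max,
                            double (*cost)(const std::vector<int>&)) {
  while (true) {
    const double current_cost = cost(*knobs);
    double best_delta = 0.0;
    int best_knob = -1;
    for (int i = 0; i < static_cast<int>(knobs->size()); ++i) {
      if ((*knobs)[i] >= max[i]) continue;
      ++(*knobs)[i];  // Tentatively increment the knob.
      const double delta = current_cost - cost(*knobs);
      --(*knobs)[i];  // Roll the increment back.
      if (delta > best_delta) {
        best_delta = delta;
        best_knob = i;
      }
    }
    if (best_knob < 0) {
      break;  // No single increment improves the cost.
    }
    ++(*knobs)[best_knob];
  }
}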
1906 
1907 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
1908                          Model::ParameterGradients* gradients) {
1909   // To store the input time for each node.
1910   Model::NodeValues input_times = {{kModelInputTimeKey, model_input_time}};
1911 
1912   // TODO(jsimsa): Now that we are accounting for buffer size in wait time
1913   // computation, assuming that the input is infinitely fast will result in
1914   // inaccurate estimates of the output latency.
1915   //
1916   // We should compute the output latency as a fixed point of the following
1917   // equation: `output_time = node->OutputTime(input_times(1, output_time))`.
1918 
1919   return node->OutputTime(&input_times, gradients);
1920 }
1921 
1922 double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
1923   return node->TotalBufferedBytes();
1924 }
1925 
1926 double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
1927   return node->TotalMaximumBufferedBytes();
1928 }
1929 
1930 double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
1931   return node->TotalProcessingTime(/*processing_times=*/nullptr);
1932 }
1933 
1934 Status Model::ToProto(ModelProto* model_proto) {
1935   tf_shared_lock l(mu_);
1936   model_proto->set_id_counter(id_counter_);
1937   model_proto->set_collect_resource_usage(collect_resource_usage_);
1938   return ModelToProtoHelper(output_, model_proto);
1939 }
1940 
1941 Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
1942   std::unique_ptr<Model> restored_model = std::make_unique<Model>();
1943   mutex_lock l(restored_model->mu_);
1944   TF_RETURN_IF_ERROR(
1945       ModelFromProtoHelper(model_proto, &restored_model->output_));
1946   restored_model->id_counter_ = model_proto.id_counter();
1947   restored_model->collect_resource_usage_.store(
1948       model_proto.collect_resource_usage());
1949   *model = std::move(restored_model);
1950   return Status::OK();
1951 }
1952 
1953 Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
1954                    const OptimizationParams& optimization_params) {
1955   ModelProto model_proto;
1956   std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
1957   {
1958     mutex_lock l(model_snapshot->mu_);
1959     model_snapshot->output_ = std::move(snapshot);
1960     model_snapshot->id_counter_ = id_counter_;
1961     model_snapshot->collect_resource_usage_.store(collect_resource_usage_);
1962   }
1963   TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
1964   OptimizationParams* saved_optimization_params =
1965       model_proto.mutable_optimization_params();
1966   *saved_optimization_params = optimization_params;
1967   return WriteBinaryProto(Env::Default(), fname, model_proto);
1968 }
1969 
1970 Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
1971                    OptimizationParams* optimization_params) {
1972   ModelProto model_proto;
1973   TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fname, &model_proto));
1974   TF_RETURN_IF_ERROR(FromProto(model_proto, model));
1975   const OptimizationParams restored_optimization_params =
1976       model_proto.optimization_params();
1977   *optimization_params = restored_optimization_params;
1978   return Status::OK();
1979 }
1980 
1981 std::string Model::DebugString() {
1982   constexpr int64_t kMinSecondsBetweenCalls = 30;
1983   if (absl::Now() < cache_until_) return cached_debug_string_;
1984   std::shared_ptr<Node> snapshot;
1985   {
1986     tf_shared_lock l(mu_);
1987     if (!output_) return cached_debug_string_;
1988     snapshot = output_->Snapshot();
1989   }
1990   // TODO(jsimsa): Populate OptimizationParams.
1991   ModelProto model_proto;
1992   Status s = ModelToProtoHelper(snapshot, &model_proto);
1993   if (s.ok()) {
1994     cached_debug_string_ = model_proto.DebugString();
1995   } else {
1996     LOG(WARNING) << s.error_message();
1997   }
1998   cache_until_ = absl::Now() + absl::Seconds(kMinSecondsBetweenCalls);
1999   return cached_debug_string_;
2000 }
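
// A minimal sketch of the time-based caching used by `Model::DebugString`
// above, wrapping a hypothetical expensive string computation. It relies only
// on the absl time utilities already included by this file and on
// std::function, which this file already uses.
class CachedDebugStringSketch {
 public:
  explicit CachedDebugStringSketch(std::function<std::string()> compute)
      : compute_(std::move(compute)) {}

  // Recomputes at most once per 30-second window; otherwise serves the cache.
  std::string Get() {
    const absl::Time now = absl::Now();
    if (now < cache_until_) {
      return cached_;
    }
    cached_ = compute_();
    cache_until_ = now + absl::Seconds(30);
    return cached_;
  }

 private:
  std::function<std::string()> compute_;
  std::string cached_;
  absl::Time cache_until_ = absl::InfinitePast();
};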
2001 
2002 }  // namespace model
2003 }  // namespace data
2004 }  // namespace tensorflow
2005