/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/core/framework/model.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "absl/time/clock.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace model {
28
// Out-of-line definitions for the static constexpr members declared in
// model.h; required to give them a single address under pre-C++17
// (non-inline-variable) ODR rules.
constexpr int64_t Model::kOptimizationPeriodMinMs;
constexpr int64_t Model::kOptimizationPeriodMaxMs;
31
32 namespace {
33
// Helper function for node traversal that doesn't skip any nodes.
//
// NOTE(review): `node` is deliberately unused; the shared_ptr is taken by
// value, presumably to match the exact predicate signature expected by the
// traversal helpers — confirm before changing it to a const reference.
inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
36
// Helper function for node traversal that filters out nodes for which
// autotuning is disabled.
//
// NOTE(review): the shared_ptr is taken by value, presumably to match the
// exact predicate signature expected by the traversal helpers — confirm
// before changing it to a const reference.
inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
  return node->autotune();
}
42
// Returns `x` squared. A plain multiply is used because it is cheaper and
// clearer than std::pow(x, 2).
inline double Square(double x) {
  const double squared = x * x;
  return squared;
}
45
46 // Collects "essential" parallelism parameters and buffer size parameters in the
47 // tree rooted in the given node. Which parallelism parameters are essential is
48 // determined by the relative processing time spent in the corresponding
49 // transformation. The collected parameters are returned via maps that map node
50 // names to their respective parameters.
CollectParameters(std::shared_ptr<Node> node,const Node::ModelParameters & parameters,Node::ModelParameters * parallelism_parameters,Node::ModelParameters * buffer_size_parameters)51 inline void CollectParameters(std::shared_ptr<Node> node,
52 const Node::ModelParameters& parameters,
53 Node::ModelParameters* parallelism_parameters,
54 Node::ModelParameters* buffer_size_parameters) {
55 // Parallelism parameter is considered to be essential if the corresponding
56 // transformations's processing time is greater than essential rate times the
57 // average transformation self processing time.
58 constexpr double kEssentialRate = 0.3L;
59
60 Node::NodeValues processing_times;
61 double processing_time = node->TotalProcessingTime(&processing_times);
62 double uniform_share =
63 processing_time / static_cast<double>(processing_times.size());
64 for (auto& pair : parameters) {
65 if (pair.second->name == kParallelism &&
66 processing_times[pair.first] > kEssentialRate * uniform_share) {
67 parallelism_parameters->push_back(pair);
68 } else if (pair.second->name == kBufferSize) {
69 buffer_size_parameters->push_back(pair);
70 }
71 }
72 }
73
74 // Applies the gradient descent method once and updates the parameter values. If
75 // the new value is out of the range, bound it within the range between the
76 // minimal and maximum values.
UpdateParameterValues(const Node::ParameterGradients & gradients,Node::ModelParameters * parameters)77 inline void UpdateParameterValues(const Node::ParameterGradients& gradients,
78 Node::ModelParameters* parameters) {
79 // Gradient descent step size.
80 constexpr double kDescentStep = 0.1L;
81 double new_value;
82
83 double max_abs_derivative = 1.0;
84 for (auto& pair : *parameters) {
85 if (std::round(pair.second->value) != pair.second->max) {
86 auto* gradient = gtl::FindOrNull(
87 gradients, std::make_pair(pair.first, pair.second->name));
88 if (gradient) {
89 max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
90 }
91 }
92 }
93 for (auto& pair : *parameters) {
94 auto* gradient = gtl::FindOrNull(
95 gradients, std::make_pair(pair.first, pair.second->name));
96 if (gradient) {
97 new_value =
98 pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
99 // Projection on a feasible interval.
100 if (new_value > pair.second->max) {
101 pair.second->value = pair.second->max;
102 } else if (new_value < pair.second->min) {
103 pair.second->value = pair.second->min;
104 } else {
105 pair.second->value = new_value;
106 }
107 }
108 }
109 }
110
111 // Copies the parameter values (which are for optimization tuning) and updates
112 // the state values (which are for the input pipeline to follow).
UpdateStateValues(Node::ModelParameters * parameters)113 inline void UpdateStateValues(Node::ModelParameters* parameters) {
114 for (auto& pair : *parameters) {
115 auto& parameter = pair.second;
116 VLOG(2) << "Setting tunable parameter " << pair.first
117 << ":: " << parameter->name << " to " << parameter->value;
118 mutex_lock l(*parameter->state->mu);
119 parameter->state->value = parameter->value;
120 parameter->state->cond_var->notify_all();
121 }
122 }
123
124 // Recursively produces protos for nodes in a subtree of `output` node and
125 // appends them to nodes of the given model.
ModelToProtoHelper(std::shared_ptr<Node> output,ModelProto * model)126 Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
127 model->set_output(output->id());
128 std::list<std::shared_ptr<Node>> to_serialize = {output};
129 auto& nodes = *model->mutable_nodes();
130 while (!to_serialize.empty()) {
131 const std::shared_ptr<Node> node = to_serialize.front();
132 to_serialize.pop_front();
133 TF_RETURN_IF_ERROR(node->ToProto(&(nodes[node->id()])));
134 for (auto input : node->inputs()) {
135 to_serialize.push_back(input);
136 }
137 }
138 return Status::OK();
139 }
140
141 // Recursively produces node tree rooted in `output` from the given model proto.
ModelFromProtoHelper(ModelProto model,std::shared_ptr<Node> * output)142 Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
143 TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
144 /*output=*/nullptr, output));
145 std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
146 while (!to_restore_inputs.empty()) {
147 std::shared_ptr<Node> node = to_restore_inputs.front();
148 to_restore_inputs.pop_front();
149 for (int64_t input_id : model.nodes().at(node->id()).inputs()) {
150 std::shared_ptr<Node> input;
151 TF_RETURN_IF_ERROR(
152 Node::FromProto(model.nodes().at(input_id), node, &input));
153 node->add_input(input);
154 to_restore_inputs.push_back(input);
155 }
156 }
157 return Status::OK();
158 }
159
// The first input of InterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class InterleaveMany : public Node {
 public:
  using Node::Node;

  virtual ~InterleaveMany() {}

 protected:
  // Returns a copy of this node that produces into `output`, preserving id
  // and name; inputs are not cloned here.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<InterleaveMany>(
        Args{id_, name_, std::move(output)});
  }

  // Computes this node's input time and stores it in `input_times` keyed by
  // `long_name()`. The inherited input time comes from the output node, or
  // from the model-level key when this node is the root.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // With at most one input there is no interleave "cycle"; pass the
    // inherited input time through unchanged.
    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for InterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except first input) to return an
    // element. Regardless of the `block_length` parameter of InterleaveMany
    // node, the average input time for any of the `(num_inputs() - 1)` input
    // nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the average
  // output time of inputs comprising the interleave "cycle".
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    // Degenerate case: no interleave cycle, so the node contributes only its
    // own processing time and its parameters have no gradient.
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    // Average output time over the cycle inputs, i.e. all inputs except the
    // first (the first input produces the datasets being interleaved and is
    // intentionally excluded; see the class comment).
    double inputs_output_time =
        (OutputTimeForInputs(*output_times) -
         (*output_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    if (gradients) {
      // Chain rule: subtree parameter gradients scale by the same averaging
      // factor applied to the output times above.
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient /= static_cast<double>(num_inputs() - 1);
        }
      }

      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients) -
          (*output_time_gradients)[inputs_.front()->long_name()];

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
    }
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // As with output time, the first input is excluded and the remainder is
    // averaged over the cycle length.
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Serializes this node to `node_proto`, tagging it with its node class.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
    return Status::OK();
  }
};
273
// The first input of AsyncInterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class AsyncInterleaveMany : public Node {
 public:
  // `parameters` are this node's tunable parameters (e.g. parallelism),
  // indexed into `parameters_` by parameter name.
  AsyncInterleaveMany(Node::Args args,
                      std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncInterleaveMany() {}

 protected:
  // Returns a copy of this node that produces into `output`, sharing this
  // node's parameter objects; inputs are not cloned here.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncInterleaveMany>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  // Computes this node's input time and stores it in `input_times` keyed by
  // `long_name()`. The inherited input time comes from the output node, or
  // from the model-level key when this node is the root.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // With at most one input there is no interleave "cycle"; pass the
    // inherited input time through unchanged.
    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for AsyncInterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except first input) to return an
    // element. Regardless of the `block_length` parameter of
    // AsyncInterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of self processing time and expected wait time
  // from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the average output time of inputs comprising the
  // interleave "cycle" divided by `parallelism`, `consumer_time` is the
  // `input_time` specified through `input_times` divided by `num_inputs() - 1`,
  // and if the node has parallelism parameter, then `buffer_size` is derived
  // from `parallelism`.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    // Degenerate case: no interleave cycle, so the node contributes only its
    // own processing time and its parameters have no gradient.
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }

    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());
    consumer_time = input_time / static_cast<double>(num_inputs() - 1);
    // Effective parallelism is capped by the cycle length even if the
    // parallelism parameter asks for more.
    double parallelism = num_inputs() - 1;  // default to cycle length
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      parallelism = std::min(parallelism, (*parameter)->value);
    }
    // Average producer time over the cycle inputs (the first input is
    // excluded; see the class comment), adjusted for parallelism.
    double output_time_for_inputs =
        OutputTimeForInputs(*output_times) -
        (*output_times)[inputs_.front()->long_name()];
    producer_time = output_time_for_inputs /
                    static_cast<double>(num_inputs() - 1) / parallelism;

    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der +
          producer_time_der * inputs_time_der_sum / parallelism;

      // Chain rule: subtree parameter gradients flow through the producer
      // time, which averages the inputs' output time over the cycle length
      // and divides by parallelism.
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (producer_time_der /
                        static_cast<double>(num_inputs() - 1) / parallelism);
        }
      }

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      for (auto& pair : inputs_.front()->CollectTunableParameters()) {
        (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
      }
      // Add derivative w.r.t. own parallelism parameter.
      if (parameter && (*parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(), (*parameter)->name)] =
            buffer_size_der - producer_time_der * producer_time / parallelism;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // The first input is excluded and the remainder is averaged over the
    // cycle length, mirroring OutputTimeLocked.
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Estimates the maximum number of bytes this node may buffer: `parallelism`
  // many elements of average buffered element size, or zero when the node has
  // no parallelism parameter.
  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      result += (*parameter)->value * AverageBufferedElementSize();
    }
    return result;
  }

  // Serializes this node to `node_proto`, tagging it with its node class.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
    return Status::OK();
  }
};
442
// Models a transformation with a statically known ratio of input elements
// consumed per output element produced (e.g. batching).
class KnownRatio : public Node {
 public:
  // `ratio` is the fixed number of input elements consumed per output element.
  KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}

  virtual ~KnownRatio() {}

 protected:
  // Returns a copy of this node that produces into `output`, preserving the
  // ratio; inputs are not cloned here.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
                                        ratio_);
  }

  // The input time is the sum of inherited input time and self processing time,
  // divided by `ratio_`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      // This node is the root of the model; use the model-level input time.
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // A zero ratio means the inputs are never consulted; pass the inherited
    // input time through unchanged (and avoid dividing by zero below).
    if (ratio_ == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the product of
  // `ratio_` and the sum of output times of inputs.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (ratio_ == 0) {
      // Inputs do not contribute to the output time, so the subtree's
      // parameters have no gradient.
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }
      }
      return;
    }
    if (gradients) {
      // Chain rule: subtree parameter gradients scale by `ratio_`, matching
      // the scaling of the inputs' output time below.
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= ratio_;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Serializes this node to `node_proto`, tagging it with its node class and
  // recording the ratio.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    return Status::OK();
  }

 private:
  // Number of input elements consumed per output element, fixed at
  // construction time.
  const double ratio_;
};
536
// Models an asynchronous (buffered) transformation with a statically known
// ratio of input elements consumed per output element produced (e.g. prefetch,
// parallel map, map-and-batch).
class AsyncKnownRatio : public Node {
 public:
  // `parameters` are this node's tunable parameters (parallelism and/or buffer
  // size), indexed into `parameters_` by parameter name. See the private
  // members for the semantics of `ratio` and `memory_ratio`.
  AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                  std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncKnownRatio() {}

 protected:
  // Returns a copy of this node that produces into `output`, sharing this
  // node's parameter objects and ratios; inputs are not cloned here.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncKnownRatio>(
        Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
  }

  // The input time is the sum of inherited input time and parallelism adjusted
  // self processing time, divided by `ratio_`.
  void InputTimeLocked(NodeValues* input_times) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      // This node is the root of the model; use the model-level input time.
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    // Self processing time is divided across `parallelism` workers when the
    // node has a parallelism parameter.
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }

    // A zero ratio means the inputs are never consulted; avoid dividing by
    // zero below.
    if (ratio_ == 0.0) {
      (*input_times)[long_name()] =
          inherited_input_time + SelfProcessingTimeLocked() / parallelism;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
        ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of parallelism adjusted self processing time and
  // expected wait time from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the product of `ratio_` and the sum of output times of
  // inputs, `consumer_time` is the product of `ratio_` and the `input_time`
  // specified through `input_times` (since for each element stored in the
  // buffer, the inputs need to be called `ratio_` times), and if the node has
  // parallelism parameter, then `buffer_size` is derived from `parallelism`.
  //
  // Current implementation assumes that there is at most 1 parameter per node.
  void OutputTimeLocked(const NodeValues& input_times,
                        ParameterGradients* gradients, NodeValues* output_times,
                        NodeValues* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double parallelism = 1.0;
    double buffer_size = 0.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
      if (ratio_ == 0) {
        buffer_size = parallelism;
      } else {
        // Currently, MapAndBatch is the only transformation creates
        // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
        // `parallelism` threads to apply the function on elements from input
        // dataset, while one element in the buffer actually corresponds to
        // `ratio_` elements from input dataset. So we adjust the `buffer_size`
        // by dividing `ratio_`.
        buffer_size = parallelism / ratio_;
      }
    } else if (buffer_size_parameter) {
      buffer_size = (*buffer_size_parameter)->value;
    }
    double self_processing_time = SelfProcessingTimeLocked();
    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());

    if (ratio_ == 0) {
      // Source-like case: the buffer is filled independently of any inputs,
      // so the producer contributes no time and subtree gradients are
      // dropped.
      consumer_time = input_time;
      producer_time = 0.0L;
      if (gradients) {
        for (const auto& pair : CollectTunableParametersLocked()) {
          gradients->erase(std::make_pair(pair.first, pair.second->name));
        }

        double producer_time_der = 0.0L;
        double consumer_time_der = 0.0L;
        double buffer_size_der = 0.0L;
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    &producer_time_der, &consumer_time_der,
                                    &buffer_size_der);
        (*output_time_gradients)[long_name()] = consumer_time_der;
        // Derivative w.r.t. own parallelism combines its effect on the
        // parallelism-adjusted self processing time with its effect on the
        // buffer size (buffer_size == parallelism in this branch).
        if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
          (*gradients)[std::make_pair(long_name(),
                                      (*parallelism_parameter)->name)] =
              -(1.0L + consumer_time_der) * self_processing_time /
                  Square(parallelism) +
              buffer_size_der;
        } else if (buffer_size_parameter &&
                   (*buffer_size_parameter)->state->tunable) {
          (*gradients)[std::make_pair(
              long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
        }
      } else {
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    /*producer_time_derivative=*/nullptr,
                                    /*consumer_time_derivative=*/nullptr,
                                    /*buffer_size_derivative=*/nullptr);
      }
      output_time = self_processing_time / parallelism + wait_time;
      (*output_times)[long_name()] = output_time;
      return;
    }

    // For each buffered element the inputs are invoked `ratio_` times,
    // scaling both the consumer and producer times.
    consumer_time = input_time * ratio_;
    producer_time = ratio_ * OutputTimeForInputs(*output_times);
    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      (*output_time_gradients)[long_name()] =
          consumer_time_der + producer_time_der * inputs_time_der_sum;

      // Chain rule: subtree parameter gradients flow through the producer
      // time, which scales the inputs' output time by `ratio_`.
      for (const auto& pair : CollectTunableParametersLocked()) {
        auto* gradient = gtl::FindOrNull(
            *gradients, std::make_pair(pair.first, pair.second->name));
        if (gradient) {
          *gradient *= (ratio_ * producer_time_der);
        }
      }

      // Add derivative w.r.t. own parameter if it's tunable.
      if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
        (*gradients)[std::make_pair(long_name(),
                                    (*parallelism_parameter)->name)] =
            buffer_size_der / ratio_ -
            (1.0L + consumer_time_der +
             producer_time_der * inputs_time_der_sum) *
                self_processing_time / Square(parallelism);
      } else if (buffer_size_parameter &&
                 (*buffer_size_parameter)->state->tunable) {
        (*gradients)[std::make_pair(
            long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time / parallelism + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(NodeValues* processing_times,
                                 NodeValues* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Estimates the maximum number of bytes this node may buffer, based on the
  // buffer size parameter (or, failing that, the parallelism parameter) and
  // the average size of buffered elements, adjusted by `memory_ratio_` when
  // it is non-zero.
  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (!parameter) {
      parameter = gtl::FindOrNull(parameters_, kParallelism);
    }

    if (parameter) {
      if (memory_ratio_ == 0) {
        result += (*parameter)->value * AverageBufferedElementSize();
      } else {
        // The estimation is currently not accurate for MapAndBatchDataset for
        // the maximum buffer size does not match `num_parallel_calls`
        // parameter.
        result +=
            (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
      }
    }
    return result;
  }

  // Serializes this node to `node_proto`, tagging it with its node class and
  // recording both ratios.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    node_proto->set_memory_ratio(memory_ratio_);
    return Status::OK();
  }

 private:
  // Identifies how many input elements need to be created to construct an
  // element for the dataset.
  //
  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset,
  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
  const double ratio_;
  // For parallelism nodes, identifies how many parallelism calls are introduced
  // by one buffered element. The value is defined to correctly estimate RAM
  // budget bound with given num_parallel_calls (or buffer_size) combined with
  // the estimated average size of buffered elements.
  const double memory_ratio_;
};
768
769 class UnknownRatio : public Node {
770 public:
771 using Node::Node;
772
~UnknownRatio()773 virtual ~UnknownRatio() {}
774
775 protected:
Clone(std::shared_ptr<Node> output) const776 std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
777 TF_SHARED_LOCKS_REQUIRED(mu_) {
778 return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
779 }
780
781 // The input time is the sum of inherited input time and self processing time,
782 // divided by the ratio estimate.
InputTimeLocked(NodeValues * input_times) const783 void InputTimeLocked(NodeValues* input_times) const override
784 TF_SHARED_LOCKS_REQUIRED(mu_) {
785 double inherited_input_time;
786 if (output_) {
787 inherited_input_time = (*input_times)[output_->long_name()];
788 } else {
789 inherited_input_time = (*input_times)[kModelInputTimeKey];
790 }
791
792 if (num_elements_ == 0 || inputs_.empty() ||
793 inputs_.front()->num_elements() == 0) {
794 (*input_times)[long_name()] = inherited_input_time;
795 return;
796 }
797 std::shared_ptr<Node> input = inputs_.front();
798 double ratio = static_cast<double>(input->num_elements()) /
799 static_cast<double>(num_elements_);
800 double input_time =
801 (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
802 (*input_times)[long_name()] = input_time;
803 }
804
805 // The output time is the sum of the self processing time and the product of
806 // the ratio estimate and the sum of output times of inputs.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const807 void OutputTimeLocked(const NodeValues& input_times,
808 ParameterGradients* gradients, NodeValues* output_times,
809 NodeValues* output_time_gradients) const override
810 TF_SHARED_LOCKS_REQUIRED(mu_) {
811 double self_processing_time = SelfProcessingTimeLocked();
812 if (num_elements_ == 0 || inputs_.empty() ||
813 inputs_.front()->num_elements() == 0) {
814 (*output_times)[long_name()] = self_processing_time;
815 if (gradients) {
816 for (const auto& pair : CollectTunableParametersLocked()) {
817 gradients->erase(std::make_pair(pair.first, pair.second->name));
818 }
819 }
820 return;
821 }
822 // TODO(jsimsa): The current implementation assumes that the number of input
823 // elements consumed per output is the same across all inputs.
824 double ratio = static_cast<double>(inputs_.front()->num_elements()) /
825 static_cast<double>(num_elements_);
826 if (gradients) {
827 for (const auto& pair : CollectTunableParametersLocked()) {
828 auto* gradient = gtl::FindOrNull(
829 *gradients, std::make_pair(pair.first, pair.second->name));
830 if (gradient) {
831 *gradient *= ratio;
832 }
833 }
834 (*output_time_gradients)[long_name()] =
835 OutputTimeGradientsForInputs(*output_time_gradients);
836 }
837 double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
838 (*output_times)[long_name()] = self_processing_time + inputs_output_time;
839 }
840
841 // The processing time is the sum of the self processing time and the product
842 // of the ratio estimate and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)843 void TotalProcessingTimeLocked(
844 absl::flat_hash_map<string, double>* processing_times,
845 absl::flat_hash_map<string, double>* total_processing_times) override
846 TF_SHARED_LOCKS_REQUIRED(mu_) {
847 double self_processing_time = SelfProcessingTimeLocked();
848 if (processing_times) {
849 (*processing_times)[long_name()] = self_processing_time;
850 }
851 if (inputs_.empty() || num_elements_ == 0) {
852 (*total_processing_times)[long_name()] = self_processing_time;
853 return;
854 }
855 // TODO(jsimsa): The current implementation assumes that the number of input
856 // elements consumed per output is the same across all inputs.
857 std::shared_ptr<Node> input = inputs_.front();
858 double ratio = static_cast<double>(input->num_elements()) /
859 static_cast<double>(num_elements_);
860 double inputs_processing_time =
861 ratio * TotalProcessingTimeForInputs(*total_processing_times);
862 (*total_processing_times)[long_name()] =
863 self_processing_time + inputs_processing_time;
864 }
865
ToProto(ModelProto::Node * node_proto) const866 Status ToProto(ModelProto::Node* node_proto) const {
867 TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
868 node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
869 return Status::OK();
870 }
871 };
872
873 class Unknown : public Node {
874 public:
875 using Node::Node;
876
~Unknown()877 virtual ~Unknown() {}
878
879 protected:
Clone(std::shared_ptr<Node> output) const880 std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
881 TF_SHARED_LOCKS_REQUIRED(mu_) {
882 return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
883 }
884
885 // The input time is the inherited input time.
InputTimeLocked(NodeValues * input_times) const886 void InputTimeLocked(NodeValues* input_times) const override
887 TF_SHARED_LOCKS_REQUIRED(mu_) {
888 double inherited_input_time;
889 if (output_) {
890 inherited_input_time = (*input_times)[output_->long_name()];
891 } else {
892 inherited_input_time = (*input_times)[kModelInputTimeKey];
893 }
894 (*input_times)[long_name()] = inherited_input_time;
895 }
896
897 // The output time is the sum of output times of inputs.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const898 void OutputTimeLocked(const NodeValues& input_times,
899 ParameterGradients* gradients, NodeValues* output_times,
900 NodeValues* output_time_gradients) const override
901 TF_SHARED_LOCKS_REQUIRED(mu_) {
902 (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
903 if (gradients) {
904 (*output_time_gradients)[long_name()] =
905 OutputTimeGradientsForInputs(*output_time_gradients);
906 }
907 }
908
909 // The processing time is the sum of processing times of inputs.
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)910 void TotalProcessingTimeLocked(NodeValues* processing_times,
911 NodeValues* total_processing_times) override
912 TF_SHARED_LOCKS_REQUIRED(mu_) {
913 if (processing_times) {
914 (*processing_times)[long_name()] = SelfProcessingTimeLocked();
915 }
916 (*total_processing_times)[long_name()] =
917 TotalProcessingTimeForInputs(*total_processing_times);
918 }
919
ToProto(ModelProto::Node * node_proto) const920 Status ToProto(ModelProto::Node* node_proto) const {
921 TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
922 node_proto->set_node_class(NodeClass::UNKNOWN);
923 return Status::OK();
924 }
925 };
926
927 } // namespace
928
// Per-thread timestamp of when the current unit of work started.
// NOTE(review): semantics inferred from the name and thread_local storage —
// confirm against the record-start/record-stop logic in the header.
thread_local int64_t Node::work_start_;
930
MakeParameter(const string & name,std::shared_ptr<SharedState> state,double min,double max)931 std::shared_ptr<Parameter> MakeParameter(const string& name,
932 std::shared_ptr<SharedState> state,
933 double min, double max) {
934 return std::make_shared<Parameter>(name, state, min, max);
935 }
936
MakeInterleaveManyNode(Node::Args args)937 std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args) {
938 return std::make_shared<InterleaveMany>(std::move(args));
939 }
940
MakeAsyncInterleaveManyNode(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)941 std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
942 Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
943 return std::make_shared<AsyncInterleaveMany>(std::move(args),
944 std::move(parameters));
945 }
946
MakeKnownRatioNode(Node::Args args,double ratio)947 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
948 return std::make_shared<KnownRatio>(std::move(args), ratio);
949 }
950
MakeAsyncKnownRatioNode(Node::Args args,double ratio,double memory_ratio,std::vector<std::shared_ptr<Parameter>> parameters)951 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
952 Node::Args args, double ratio, double memory_ratio,
953 std::vector<std::shared_ptr<Parameter>> parameters) {
954 return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
955 std::move(parameters));
956 }
957
MakeAsyncKnownRatioNode(Node::Args args,double ratio,std::vector<std::shared_ptr<Parameter>> parameters)958 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
959 Node::Args args, double ratio,
960 std::vector<std::shared_ptr<Parameter>> parameters) {
961 return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
962 /*memory_ratio=*/ratio, std::move(parameters));
963 }
964
MakeSourceNode(Node::Args args)965 std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
966 return MakeKnownRatioNode(std::move(args), 0);
967 }
968
MakeUnknownRatioNode(Node::Args args)969 std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
970 return std::make_shared<UnknownRatio>(std::move(args));
971 }
972
MakeUnknownNode(Node::Args args)973 std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
974 return std::make_shared<Unknown>(std::move(args));
975 }
976
// Computes the expected time a consumer spends waiting on a producer filling
// a buffer of size `buffer_size`, where `producer_time` and `consumer_time`
// are the expected per-element production and consumption times. If the
// derivative out-params are non-null, the partial derivatives of the wait
// time w.r.t. the corresponding inputs are written to them. Returns the
// expected wait time (p_buffer_empty * producer_time; see below).
double Node::ComputeWaitTime(const double& producer_time,
                             const double& consumer_time,
                             const double& buffer_size,
                             double* producer_time_derivative,
                             double* consumer_time_derivative,
                             double* buffer_size_derivative) {
  // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
  // p=`p_buffer_empty`, T=`wait_time`, then we have:
  // if y = 0, then p = 0;
  // elif x = 0, then p = 1;
  // elif x = y, then p = 1 / (n+1);
  // else p = [1 - x/y] / [1 - power(x/y, n+1)].
  //
  // We also have T = p * y, and derivatives of T w.r.t. x, y, n are computed:
  // dT/dx = dp/dx * y,
  // dT/dy = p + dp/dy * y,
  // dT/dn = dp/dn * y.
  // Then the remaining work is to compute dp/dx, dp/dy, dp/dn by considering
  // different cases and substitute the values into above formulas.

  // Case 1: if producer is infinitely fast. The buffer will always be full.
  // Wait time will always be 0.
  if (producer_time == 0) {
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = 0` since p=0 on the
      // line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy at
      // (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy, where
      // p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
      if (buffer_size == 0 || consumer_time == 0) {
        *producer_time_derivative = 1.0L;
      } else {
        *producer_time_derivative = 0.0L;
      }
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = 0.0L;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return 0.0L;
  }

  // Case 2: if consumer is infinitely fast. Wait time is always the time to
  // produce an output.
  if (consumer_time == 0) {
    if (producer_time_derivative) {
      *producer_time_derivative = 1.0L;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since p=1 on the
      // line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx at
      // (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx, where
      // p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
      if (buffer_size == 0) {
        *consumer_time_derivative = 0.0L;
      } else {
        *consumer_time_derivative = -1.0L;
      }
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return producer_time;
  }

  // Case 3: the consumer and the producer are equally fast. Expected wait time
  // decreases linearly with the size of the buffer.
  if (consumer_time == producer_time) {
    const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
    const double p_buffer_empty_der =
        -buffer_size / (2.0L * buffer_size + 2.0L);
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = p_buffer_empty`
      // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there. Actually
      // to compute dp/dy at (y,y), we need to consider
      // lim_{dy->0} [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1),
      // p(y,y+dy) = [1 - y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
      *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since
      // p=1/(n+1) on the line x=y doesn't imply dp/dx = 0 there. Actually to
      // compute dp/dx at (x,x), we need to consider
      // lim_{dx->0} [p(x+dx,x)-p(x,x)] / dx, where p(x,x)=1/(n+1),
      // p(x+dx,x) = [1 - (x+dx)/x] / [1 - power((x+dx)/x, n+1)].
      *consumer_time_derivative = p_buffer_empty_der;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
    }
    return p_buffer_empty * producer_time;
  }

  // Case 4: the consumer is slower than the producer and neither is infinitely
  // fast. Case 4 and Case 5 actually follow same formula. Separate them for
  // numerical computation reasons.
  if (consumer_time > producer_time) {
    const double ratio = producer_time / consumer_time;
    const double ratio_pow = std::pow(ratio, buffer_size);
    const double p_buffer_empty =
        ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
    const double p_buffer_empty_der =
        (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
        ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
    if (producer_time_derivative) {
      *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                                std::log(ratio) * producer_time;
    }
    return p_buffer_empty * producer_time;
  }

  // Case 5: the producer is slower than the consumer and neither is infinitely
  // fast.
  const double ratio = consumer_time / producer_time;
  const double ratio_pow = std::pow(ratio, buffer_size);
  const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
  const double p_buffer_empty_der =
      ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
      Square(1.0L - ratio_pow * ratio);
  if (producer_time_derivative) {
    *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
  }
  if (consumer_time_derivative) {
    *consumer_time_derivative = p_buffer_empty_der;
  }
  if (buffer_size_derivative) {
    *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                              ratio_pow * ratio * std::log(ratio) *
                              producer_time;
  }
  return p_buffer_empty * producer_time;
}
1116
CollectTunableParametersLocked() const1117 Node::ModelParameters Node::CollectTunableParametersLocked() const {
1118 Node::ModelParameters parameters;
1119 // Collect tunable parameters from the leaves of the nodes tree to the root.
1120 for (const auto& node :
1121 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1122 tf_shared_lock l(node->mu_);
1123 node->CollectTunableParametersHelper(¶meters);
1124 }
1125 CollectTunableParametersHelper(¶meters);
1126 return parameters;
1127 }
1128
CollectTunableParameters() const1129 Node::ModelParameters Node::CollectTunableParameters() const {
1130 tf_shared_lock l(mu_);
1131 return CollectTunableParametersLocked();
1132 }
1133
DebugString() const1134 string Node::DebugString() const {
1135 absl::flat_hash_map<string, string> debug_strings;
1136 tf_shared_lock l(mu_);
1137 // Build up the debug string from the leaves of the nodes tree to the root.
1138 for (const auto& node :
1139 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1140 tf_shared_lock l(node->mu_);
1141 node->DebugStringHelper(&debug_strings);
1142 }
1143 DebugStringHelper(&debug_strings);
1144
1145 return debug_strings[long_name()];
1146 }
1147
FlushMetrics()1148 void Node::FlushMetrics() {
1149 if (!record_metrics_) {
1150 return;
1151 }
1152 metrics_.record_bytes_consumed(bytes_consumed_);
1153 metrics_.record_bytes_produced(bytes_produced_);
1154 metrics_.record_num_elements(num_elements_);
1155 }
1156
OutputTime(Node::NodeValues * input_times,Node::ParameterGradients * gradients) const1157 double Node::OutputTime(Node::NodeValues* input_times,
1158 Node::ParameterGradients* gradients) const {
1159 // To store the output time gradient w.r.t. input time (if `gradients` is not
1160 // `nullptr`) and the output time for each node.
1161 Node::NodeValues output_time_gradients, output_times;
1162 tf_shared_lock l(mu_);
1163 auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);
1164
1165 // Computes and stores input time for each node from the root to leaves of the
1166 // nodes tree.
1167 InputTimeLocked(input_times);
1168 for (const auto& node : nodes) {
1169 tf_shared_lock l(node->mu_);
1170 node->InputTimeLocked(input_times);
1171 }
1172
1173 std::reverse(nodes.begin(), nodes.end());
1174 // Computes and stores the output time and output time gradient w.r.t. input
1175 // time (if `gradients` is not `nullptr`) for each node from leaves of the
1176 // nodes tree to the root.
1177 for (const auto& node : nodes) {
1178 tf_shared_lock l(node->mu_);
1179 node->OutputTimeLocked(*input_times, gradients, &output_times,
1180 &output_time_gradients);
1181 }
1182 OutputTimeLocked(*input_times, gradients, &output_times,
1183 &output_time_gradients);
1184
1185 return output_times[long_name()];
1186 }
1187
Snapshot() const1188 std::shared_ptr<Node> Node::Snapshot() const {
1189 NodePairList node_pairs;
1190 auto result = SnapshotHelper(nullptr, &node_pairs);
1191
1192 while (!node_pairs.empty()) {
1193 auto node_pair = node_pairs.front();
1194 node_pairs.pop_front();
1195 std::shared_ptr<Node> current = node_pair.first,
1196 cloned_output = node_pair.second;
1197 cloned_output->add_input(
1198 current->SnapshotHelper(cloned_output, &node_pairs));
1199 }
1200 return result;
1201 }
1202
SelfProcessingTime() const1203 double Node::SelfProcessingTime() const {
1204 tf_shared_lock l(mu_);
1205 return SelfProcessingTimeLocked();
1206 }
1207
TotalBufferedBytes() const1208 double Node::TotalBufferedBytes() const {
1209 Node::NodeValues total_bytes;
1210 tf_shared_lock l(mu_);
1211 // Compute total buffered bytes from the leaves of the nodes tree to the root.
1212 for (const auto& node :
1213 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1214 tf_shared_lock l(node->mu_);
1215 node->TotalBufferedBytesHelper(&total_bytes);
1216 }
1217 TotalBufferedBytesHelper(&total_bytes);
1218
1219 return total_bytes[long_name()];
1220 }
1221
TotalMaximumBufferedBytes() const1222 double Node::TotalMaximumBufferedBytes() const {
1223 Node::NodeValues total_bytes;
1224 tf_shared_lock l(mu_);
1225 // Compute total maximum buffered bytes from the leaves of the nodes tree
1226 // to the root.
1227 for (const auto& node :
1228 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1229 tf_shared_lock l(node->mu_);
1230 node->TotalMaximumBufferedBytesHelper(&total_bytes);
1231 }
1232 TotalMaximumBufferedBytesHelper(&total_bytes);
1233
1234 return total_bytes[long_name()];
1235 }
1236
TotalProcessingTime(Node::NodeValues * processing_times)1237 double Node::TotalProcessingTime(Node::NodeValues* processing_times) {
1238 // Create a hash map to store the per-element CPU time spent in the subtree
1239 // rooted in each node.
1240 Node::NodeValues total_processing_times;
1241 tf_shared_lock l(mu_);
1242
1243 // Computes per-element CPU time spent in the subtree rooted in the node from
1244 // the leaves of the nodes tree to the root.
1245 for (const auto& node :
1246 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1247 tf_shared_lock l(node->mu_);
1248 node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1249 }
1250 TotalProcessingTimeLocked(processing_times, &total_processing_times);
1251
1252 return total_processing_times[long_name()];
1253 }
1254
AverageBufferedElementSize() const1255 double Node::AverageBufferedElementSize() const {
1256 DCHECK_GE(num_elements_, 0);
1257 DCHECK_GE(buffered_elements_, 0);
1258 if (num_elements_ <= 0) {
1259 if (buffered_elements_ <= 0) {
1260 // If there are no produced elements or buffered elements recorded, return
1261 // 0.
1262 return 0;
1263 }
1264 // If there are no produced elements but some buffered elements, return the
1265 // average size of all buffered elements.
1266 return static_cast<double>(buffered_bytes_) /
1267 static_cast<double>(buffered_elements_);
1268 }
1269
1270 if (buffered_elements_ <= 0) {
1271 // If there are no buffered elements but some produced elements, return the
1272 // average size of all produced elements.
1273 return static_cast<double>(bytes_produced_) /
1274 static_cast<double>(num_elements_);
1275 }
1276
1277 // Otherwise, return the mean value of average size of all produced elements
1278 // and average size of all buffered elements.
1279 return (static_cast<double>(bytes_produced_) /
1280 static_cast<double>(num_elements_) +
1281 static_cast<double>(buffered_bytes_) /
1282 static_cast<double>(buffered_elements_)) /
1283 2.0;
1284 }
1285
OutputTimeForInputs(const Node::NodeValues & output_times) const1286 double Node::OutputTimeForInputs(const Node::NodeValues& output_times) const {
1287 double sum = 0;
1288 for (auto& input : inputs_) {
1289 // Inputs for which autotuning is disabled are excluded.
1290 if (input->autotune()) {
1291 sum += output_times.at(input->long_name());
1292 }
1293 }
1294 return sum;
1295 }
1296
OutputTimeGradientsForInputs(const Node::NodeValues & output_time_gradients) const1297 double Node::OutputTimeGradientsForInputs(
1298 const Node::NodeValues& output_time_gradients) const {
1299 double sum = 0;
1300 for (auto& input : inputs_) {
1301 // Inputs for which autotuning is disabled are excluded.
1302 if (input->autotune()) {
1303 sum +=
1304 gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1305 }
1306 }
1307 return sum;
1308 }
1309
// Sums the total processing times of this node's autotune-enabled inputs,
// smoothing the estimate for inputs that have produced few elements by
// blending in the node's processing-time history. Updates the history
// (`input_processing_time_sum_` / `input_processing_time_count_`) from inputs
// with enough observed elements.
double Node::TotalProcessingTimeForInputs(
    const Node::NodeValues& total_processing_times) {
  // If the number of elements produced by an input is smaller than this
  // constant, then its processing time is estimated using a weighted average
  // of the empirical processing time and processing time history.
  constexpr int kNumElementsThreshold = 30;

  // Identifies the minimum number of input processing times to collect
  // before the processing time history is used as a prior.
  constexpr int kCountThreshold = 30;

  double sum = 0;
  for (auto& input : inputs_) {
    // Inputs for which autotuning is disabled are excluded.
    if (input->autotune()) {
      double input_processing_time =
          total_processing_times.at(input->long_name());
      int64_t num_elements = input->num_elements();
      if (num_elements < kNumElementsThreshold) {
        if (input_processing_time_count_ < kCountThreshold) {
          // Not enough history yet; use the empirical time as-is.
          sum += input_processing_time;
        } else {
          // The fewer elements the input has produced so far, the more weight
          // is assigned to the prior to reduce volatility.
          // Note: `2 << num_elements` equals 2^(num_elements + 1), so the
          // prior weight decays exponentially with the element count.
          double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
          double prior =
              input_processing_time_sum_ / input_processing_time_count_;
          sum += (1.0L - prior_weight) * input_processing_time +
                 prior_weight * prior;
        }
      } else {
        // Enough elements observed: use the empirical time directly and fold
        // it into the history that serves as a prior for other inputs.
        sum += input_processing_time;
        input_processing_time_count_++;
        input_processing_time_sum_ += input_processing_time;
      }
    }
  }
  return sum;
}
1349
SelfProcessingTimeLocked() const1350 double Node::SelfProcessingTimeLocked() const {
1351 if (num_elements_ == 0) {
1352 return 0;
1353 }
1354 return static_cast<double>(processing_time_) /
1355 static_cast<double>(num_elements_);
1356 }
1357
CollectNodes(TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const1358 Node::NodeVector Node::CollectNodes(
1359 TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1360 TF_SHARED_LOCKS_REQUIRED(mu_) {
1361 NodeVector node_vector;
1362 std::list<std::shared_ptr<Node>> temp_list;
1363
1364 for (auto& input : inputs_) {
1365 if (collect_node(input)) {
1366 node_vector.push_back(input);
1367 temp_list.push_back(input);
1368 }
1369 }
1370
1371 while (!temp_list.empty()) {
1372 auto cur_node = temp_list.front();
1373 temp_list.pop_front();
1374 tf_shared_lock l(cur_node->mu_);
1375 for (auto& input : cur_node->inputs_) {
1376 if (collect_node(input)) {
1377 node_vector.push_back(input);
1378 temp_list.push_back(input);
1379 }
1380 }
1381 }
1382
1383 if (order == TraversalOrder::REVERSE_BFS) {
1384 std::reverse(node_vector.begin(), node_vector.end());
1385 }
1386 return node_vector;
1387 }
1388
CollectTunableParametersHelper(Node::ModelParameters * parameters) const1389 void Node::CollectTunableParametersHelper(
1390 Node::ModelParameters* parameters) const TF_SHARED_LOCKS_REQUIRED(mu_) {
1391 // If autotune is turned off or there are no elements recorded, we don't
1392 // collect the parameters on the node.
1393 if (!autotune_ || num_elements_ <= 0) {
1394 return;
1395 }
1396 for (auto& pair : parameters_) {
1397 if (pair.second->state->tunable) {
1398 parameters->push_back(std::make_pair(long_name(), pair.second));
1399 }
1400 }
1401 }
1402
// Builds a human-readable description of this node (counters, inputs, and the
// already-built descriptions of its inputs) and stores it in `debug_strings`
// keyed by `long_name()`. Requires that the debug strings of all inputs are
// already present in `debug_strings` (guaranteed by the reverse-BFS traversal
// in `DebugString`).
void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
    const TF_SHARED_LOCKS_REQUIRED(mu_) {
  string result;
  strings::StrAppend(&result, long_name(), ":\n");
  strings::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
  strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(),
                     "\n");
  strings::StrAppend(&result, "  buffered_elements=", buffered_elements_.load(),
                     "\n");
  strings::StrAppend(&result, "  bytes_consumed=", bytes_consumed_.load(),
                     "\n");
  strings::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(),
                     "\n");
  strings::StrAppend(&result, "  processing_time=", processing_time_.load(),
                     "\n");
  strings::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
  string inputs;
  for (auto& input : inputs_) {
    strings::StrAppend(&inputs, input->long_name(), ",");
  }
  strings::StrAppend(&result, "  inputs={", inputs, "}\n");
  // Append the (already computed) descriptions of the inputs.
  for (auto& input : inputs_) {
    strings::StrAppend(&result, debug_strings->at(input->long_name()));
  }
  debug_strings->insert(std::make_pair(long_name(), result));
}
1429
// Clones this node with `cloned_output` as the clone's output, copying the
// node's counters and parameters, and enqueues (original input, clone) pairs
// on `node_pairs` so the caller can clone the inputs iteratively. Metrics
// recording is disabled on the clone so snapshots do not double-report.
std::shared_ptr<Node> Node::SnapshotHelper(
    std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
  tf_shared_lock l(mu_);

  // Clone current node(`this`), also set clone of its output node
  // (`cloned_output`) to be the output node of the cloned node
  // (`cloned_current`).
  std::shared_ptr<Node> cloned_current = Clone(cloned_output);
  {
    // The atomic counters are copied via store; only `parameters_` requires
    // the clone's mutex.
    cloned_current->autotune_.store(autotune_);
    cloned_current->buffered_bytes_.store(buffered_bytes_);
    cloned_current->buffered_elements_.store(buffered_elements_);
    cloned_current->bytes_consumed_.store(bytes_consumed_);
    cloned_current->bytes_produced_.store(bytes_produced_);
    cloned_current->num_elements_.store(num_elements_);
    cloned_current->record_metrics_.store(false);
    cloned_current->processing_time_.store(processing_time_);
    mutex_lock l2(cloned_current->mu_);
    cloned_current->parameters_ = parameters_;
  }

  // Defer cloning of the inputs to the caller (iterative BFS in `Snapshot`).
  for (auto& input : inputs_) {
    node_pairs->push_back(std::make_pair(input, cloned_current));
  }
  return cloned_current;
}
1456
TotalBufferedBytesHelper(Node::NodeValues * total_bytes) const1457 void Node::TotalBufferedBytesHelper(Node::NodeValues* total_bytes) const
1458 TF_SHARED_LOCKS_REQUIRED(mu_) {
1459 if (!autotune_) {
1460 total_bytes->insert(std::make_pair(long_name(), 0));
1461 return;
1462 }
1463
1464 double result = 0;
1465 auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1466 if (!parameter) {
1467 parameter = gtl::FindOrNull(parameters_, kParallelism);
1468 }
1469 if (parameter) {
1470 result = buffered_bytes_;
1471 }
1472 for (auto& input : inputs_) {
1473 result += total_bytes->at(input->long_name());
1474 }
1475 total_bytes->insert(std::make_pair(long_name(), result));
1476 }
1477
TotalMaximumBufferedBytesHelper(Node::NodeValues * total_bytes) const1478 void Node::TotalMaximumBufferedBytesHelper(Node::NodeValues* total_bytes) const
1479 TF_SHARED_LOCKS_REQUIRED(mu_) {
1480 if (!autotune_) {
1481 total_bytes->insert(std::make_pair(long_name(), 0));
1482 return;
1483 }
1484
1485 double result = MaximumBufferedBytes();
1486 for (auto& input : inputs_) {
1487 result += total_bytes->at(input->long_name());
1488 }
1489 total_bytes->insert(std::make_pair(long_name(), result));
1490 }
1491
// Base implementation of the maximum-buffered-bytes estimate: a plain node
// does not buffer elements, so the estimate is zero. Buffering node classes
// (e.g. AsyncKnownRatio above) provide their own estimate.
double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
  return 0;
}
1495
// Serializes the node's scalar state, its parameters, and the ids of its
// inputs into `node_proto` under a shared lock. Inputs are recorded by id
// only; their contents are serialized as separate protos.
Status Node::ToProto(ModelProto::Node* node_proto) const {
  tf_shared_lock l(mu_);
  node_proto->set_id(id_);
  node_proto->set_name(name_);
  node_proto->set_autotune(autotune_);
  node_proto->set_buffered_bytes(buffered_bytes_);
  node_proto->set_buffered_elements(buffered_elements_);
  node_proto->set_bytes_consumed(bytes_consumed_);
  node_proto->set_bytes_produced(bytes_produced_);
  node_proto->set_num_elements(num_elements_);
  node_proto->set_processing_time(processing_time_);
  node_proto->set_record_metrics(record_metrics_);

  // Produce protos for all parameters.
  for (auto const& parameter : parameters_) {
    ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
    parameter_proto->set_name(parameter.first);
    parameter_proto->set_value(parameter.second->value);
    parameter_proto->set_min(parameter.second->min);
    parameter_proto->set_max(parameter.second->max);
    parameter_proto->set_state_value(parameter.second->state->value);
    parameter_proto->set_tunable(parameter.second->state->tunable);
  }

  // Add input node ids.
  for (auto const& input : inputs_) {
    node_proto->add_inputs(input->id());
  }
  return Status::OK();
}
1526
// Restores the mutable state of `node` from `node_proto`: first the atomic
// counters, then the parameters. Tunable parameters are recreated in autotune
// mode (`kAutotune`) with their shared state value taken from the proto;
// non-tunable parameters are recreated directly with the serialized state
// value.
Status Node::FromProtoHelper(ModelProto::Node node_proto,
                             std::shared_ptr<Node> node) {
  tf_shared_lock l(node->mu_);
  node->autotune_.store(node_proto.autotune());
  node->buffered_bytes_.store(node_proto.buffered_bytes());
  node->buffered_elements_.store(node_proto.buffered_elements());
  node->bytes_consumed_.store(node_proto.bytes_consumed());
  node->bytes_produced_.store(node_proto.bytes_produced());
  node->num_elements_.store(node_proto.num_elements());
  node->processing_time_.store(node_proto.processing_time());
  node->record_metrics_.store(node_proto.record_metrics());

  // Restore parameters.
  int64_t num_parameters = node_proto.parameters_size();
  for (int i = 0; i < num_parameters; i++) {
    const ModelProto::Node::Parameter& parameter_proto =
        node_proto.parameters(i);
    std::shared_ptr<SharedState> state;
    if (parameter_proto.tunable()) {
      // Tunable: create the state in autotune mode, then restore the actual
      // serialized state value.
      state =
          std::make_shared<SharedState>(kAutotune, std::make_shared<mutex>(),
                                        std::make_shared<condition_variable>());
      state->value = parameter_proto.state_value();
    } else {
      // Non-tunable: the serialized state value is used as-is.
      state = std::make_shared<SharedState>(
          parameter_proto.state_value(), std::make_shared<mutex>(),
          std::make_shared<condition_variable>());
    }
    node->parameters_[parameter_proto.name()] =
        MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
                      parameter_proto.max());
  }
  return Status::OK();
}
1561
FromProto(ModelProto::Node node_proto,std::shared_ptr<Node> output,std::shared_ptr<Node> * node)1562 Status Node::FromProto(ModelProto::Node node_proto,
1563 std::shared_ptr<Node> output,
1564 std::shared_ptr<Node>* node) {
1565 // Note that parameters are restored in `FromProtoHelper`.
1566 Args args = {node_proto.id(), node_proto.name(), std::move(output)};
1567 switch (node_proto.node_class()) {
1568 case NodeClass::INTERLEAVE_MANY:
1569 *node = std::make_shared<InterleaveMany>(args);
1570 break;
1571 case NodeClass::ASYNC_INTERLEAVE_MANY:
1572 *node = std::make_shared<AsyncInterleaveMany>(
1573 args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1574 break;
1575 case NodeClass::KNOWN_RATIO:
1576 *node = std::make_shared<KnownRatio>(args, node_proto.ratio());
1577 break;
1578 case NodeClass::ASYNC_KNOWN_RATIO:
1579 *node = std::make_shared<AsyncKnownRatio>(
1580 args, node_proto.ratio(), node_proto.memory_ratio(),
1581 /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
1582 break;
1583 case NodeClass::UNKNOWN_RATIO:
1584 *node = std::make_shared<UnknownRatio>(args);
1585 break;
1586 default:
1587 *node = std::make_shared<Unknown>(args);
1588 }
1589 return FromProtoHelper(node_proto, *node);
1590 }
1591
Model()1592 Model::Model()
1593 : collect_resource_usage_(false),
1594 optimization_period_ms_(kOptimizationPeriodMinMs) {
1595 model_gauge_cell_ = metrics::GetTFDataModelGauge(
1596 strings::StrCat(reinterpret_cast<uint64>(this)));
1597 model_gauge_cell_->Set([&]() { return DebugString(); });
1598 }
1599
~Model()1600 Model::~Model() {
1601 // Before the model is destroyed, we record its final state in the gauge.
1602 auto result = DebugString();
1603 model_gauge_cell_->Set([result]() { return result; });
1604 }
1605
AddNode(Node::Factory factory,const string & name,std::shared_ptr<Node> parent,std::shared_ptr<Node> * out_node)1606 void Model::AddNode(Node::Factory factory, const string& name,
1607 std::shared_ptr<Node> parent,
1608 std::shared_ptr<Node>* out_node) {
1609 // The name captures the sequence of iterators joined by `::`. We only use the
1610 // last element of the sequence as the name node.
1611 auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
1612 mutex_lock l(mu_);
1613 std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
1614 if (!output_) {
1615 output_ = node;
1616 }
1617 if (parent) {
1618 VLOG(3) << "Adding " << node->long_name() << " as input for "
1619 << parent->long_name();
1620 parent->add_input(node);
1621 } else {
1622 VLOG(3) << "Adding " << node->long_name();
1623 }
1624 collect_resource_usage_ =
1625 collect_resource_usage_ || node->has_tunable_parameters();
1626 *out_node = std::move(node);
1627 // TODO(jsimsa): Reset the optimization period when a node is added so that
1628 // autotuning adapts to changes to the input pipeline faster. Initial attempt
1629 // to enable this functionality caused a regression (see b/179812091).
1630 }
1631
FlushMetrics()1632 void Model::FlushMetrics() {
1633 std::deque<std::shared_ptr<Node>> queue;
1634 {
1635 tf_shared_lock l(mu_);
1636 if (output_) queue.push_back(output_);
1637 }
1638 while (!queue.empty()) {
1639 auto node = queue.front();
1640 queue.pop_front();
1641 node->FlushMetrics();
1642 for (auto input : node->inputs()) {
1643 queue.push_back(input);
1644 }
1645 }
1646 }
1647
Optimize(AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget,double model_input_time,CancellationManager * cancellation_manager)1648 void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
1649 int64_t ram_budget, double model_input_time,
1650 CancellationManager* cancellation_manager) {
1651 std::shared_ptr<Node> snapshot;
1652 {
1653 tf_shared_lock l(mu_);
1654 snapshot = output_->Snapshot();
1655 }
1656 OptimizationParams optimization_params;
1657 optimization_params.set_algorithm(algorithm);
1658 optimization_params.set_cpu_budget(cpu_budget);
1659 optimization_params.set_ram_budget(ram_budget);
1660 optimization_params.set_model_input_time(model_input_time);
1661 switch (algorithm) {
1662 case AutotuneAlgorithm::HILL_CLIMB:
1663 OptimizeHillClimb(snapshot, optimization_params, cancellation_manager);
1664 break;
1665 case AutotuneAlgorithm::GRADIENT_DESCENT:
1666 OptimizeGradientDescent(snapshot, optimization_params,
1667 cancellation_manager);
1668 break;
1669 default:
1670 VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
1671 "optimization.";
1672 return;
1673 }
1674 }
1675
RemoveNode(std::shared_ptr<Node> node)1676 void Model::RemoveNode(std::shared_ptr<Node> node) {
1677 mutex_lock l(mu_);
1678 if (node) {
1679 if (node->output()) {
1680 node->output()->remove_input(node);
1681 }
1682 VLOG(3) << "Removing " << node->long_name();
1683 }
1684 }
1685
// Returns the tunable parameters of the tree rooted at `node`, delegating to
// the node's own traversal.
Model::ModelParameters Model::CollectTunableParameters(
    std::shared_ptr<Node> node) {
  return node->CollectTunableParameters();
}
1690
ShouldStop(int64_t cpu_budget,int64_t ram_budget,const Model::ModelParameters & parameters,const Model::ModelParameters & parallelism_parameters,const Model::ModelParameters & buffer_size_parameters,std::shared_ptr<Node> snapshot,bool * cpu_budget_reached)1691 bool Model::ShouldStop(int64_t cpu_budget, int64_t ram_budget,
1692 const Model::ModelParameters& parameters,
1693 const Model::ModelParameters& parallelism_parameters,
1694 const Model::ModelParameters& buffer_size_parameters,
1695 std::shared_ptr<Node> snapshot,
1696 bool* cpu_budget_reached) {
1697 if (!(*cpu_budget_reached)) {
1698 // If those essential transformations' parallelism reaches the CPU
1699 // budget, we will only tune the buffer size parameters in future
1700 // iterations.
1701 int64_t model_parallelism = 0;
1702 for (auto& pair : parallelism_parameters) {
1703 model_parallelism += std::round(pair.second->value);
1704 }
1705 *cpu_budget_reached = (model_parallelism > cpu_budget);
1706 }
1707
1708 bool all_max = true;
1709 for (auto& pair :
1710 (*cpu_budget_reached ? buffer_size_parameters : parameters)) {
1711 if (std::round(pair.second->value) < pair.second->max) {
1712 all_max = false;
1713 break;
1714 }
1715 }
1716
1717 // If all parameters have reached their maximum values or RAM budget is
1718 // reached, we stop the iterations.
1719 return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
1720 }
1721
1722 // TODO(jsimsa): Add support for tracking and using the model input time.
// Repeatedly runs `Optimize` until cancellation. The period between runs
// doubles after each optimization, capped at `kOptimizationPeriodMaxMs`.
Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
                           int64_t ram_budget,
                           CancellationManager* cancellation_manager) {
  std::function<void()> unused;
  // Wake the waiting loop when cancellation is requested so it can exit
  // promptly instead of sleeping out the remainder of the period.
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      cancellation_manager,
      [this]() {
        mutex_lock l(mu_);
        optimize_cond_var_.notify_all();
      },
      /*deregister_fn=*/&unused));

  int64_t last_optimization_ms = 0;
  int64_t current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
  while (true) {
    {
      mutex_lock l(mu_);
      // Wait (re-checking after each wakeup, which also tolerates spurious
      // wakeups) until the next optimization is due or we are cancelled.
      while (!cancellation_manager->IsCancelled() &&
             last_optimization_ms + optimization_period_ms_ > current_time_ms) {
        auto wait_ms =
            last_optimization_ms + optimization_period_ms_ - current_time_ms;
        VLOG(2) << "Waiting for " << wait_ms << " ms.";
        optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
        current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
      }
      if (cancellation_manager->IsCancelled()) {
        return Status::OK();
      }
    }

    int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
    Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0,
             cancellation_manager);
    int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
    VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";

    // Exponentially increase the period of running the optimization
    // until a threshold is reached.
    {
      mutex_lock l(mu_);
      optimization_period_ms_ =
          std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
    }
    current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
    last_optimization_ms = current_time_ms;
    FlushMetrics();
  }
}
1771
// Tunes parameters with (projected) gradient descent: repeatedly computes
// `OutputTime` and its gradients, then moves parameter values against the
// gradients until the latency improvement drops below
// `kOptimizationPrecision`, a budget is hit, or `kMaxIterations` is
// exhausted. Final (rounded) values are committed via `UpdateStateValues`.
void Model::OptimizeGradientDescent(
    std::shared_ptr<Node> snapshot,
    const OptimizationParams& optimization_params,
    CancellationManager* cancellation_manager) {
  VLOG(2) << "Starting optimization of tunable parameters with Gradient "
             "Descent.";
  auto parameters = CollectTunableParameters(snapshot);
  if (parameters.empty()) {
    VLOG(2) << "The Gradient Descent optimization is terminated since no node "
               "with tunable parameters has recorded elements.";
    return;
  }
  VLOG(2) << "Number of tunable parameters: " << parameters.size();

  // The vectors of "essential" parallelism parameters and buffer size
  // parameters.
  Model::ModelParameters parallelism_parameters, buffer_size_parameters;
  CollectParameters(snapshot, parameters, &parallelism_parameters,
                    &buffer_size_parameters);

  // Initialize the parameter values to minimal before tuning.
  for (auto& pair : parameters) {
    pair.second->value = pair.second->min;
  }

  // Optimization is stopped once the `OutputTime` improvement is smaller than
  // this value.
  constexpr double kOptimizationPrecision = 100.0L;

  // Maximum number of iterations for optimization.
  constexpr int64_t kMaxIterations = 1000;

  double output_time = 0;
  double new_output_time;

  // When the CPU budget is reached, the parallelism parameter values are fixed
  // and we only increase the buffer size parameters.
  bool cpu_budget_reached = false;

  for (int i = 0; i < kMaxIterations; ++i) {
    if (cancellation_manager->IsCancelled() ||
        ShouldStop(optimization_params.cpu_budget(),
                   optimization_params.ram_budget(), parameters,
                   parallelism_parameters, buffer_size_parameters, snapshot,
                   &cpu_budget_reached)) {
      break;
    }
    Model::ParameterGradients gradients;
    new_output_time = OutputTime(
        snapshot, optimization_params.model_input_time(), &gradients);
    // We also terminate once the improvement of the output latency is too
    // small.
    if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
      break;
    }

    // After the CPU budget is reached, only buffer-size parameters keep
    // moving along their gradients.
    UpdateParameterValues(
        gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
    output_time = new_output_time;
  }

  // Gradient descent works on continuous values; round to integers before
  // committing the results to the shared parameter state.
  for (auto& pair : parameters) {
    pair.second->value = std::round(pair.second->value);
  }
  UpdateStateValues(&parameters);
}
1838
// Tunes parameters with greedy hill climbing: starting from minimum values,
// repeatedly increments the one parameter whose unit increment yields the
// largest output-latency improvement, stopping at the latency target, a
// budget limit, or a local maximum. Results are committed via
// `UpdateStateValues`.
void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                              const OptimizationParams& optimization_params,
                              CancellationManager* cancellation_manager) {
  VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
  const double processing_time = TotalProcessingTime(snapshot);
  auto parameters = CollectTunableParameters(snapshot);
  if (parameters.empty()) {
    VLOG(2) << "The Hill Climb optimization is terminated since no node with "
               "tunable parameters has recorded elements.";
    return;
  }
  VLOG(2) << "Number of tunable parameters: " << parameters.size();

  // Buffer size parameter will only be incremented if the output latency
  // improvement is greater than this constant.
  constexpr double kBufferSizeMinDelta = 1.0L;

  // Initialize the parameter values to minimal before tuning.
  for (auto& pair : parameters) {
    pair.second->value = pair.second->min;
  }
  while (!cancellation_manager->IsCancelled()) {
    const double output_time =
        OutputTime(snapshot, optimization_params.model_input_time(),
                   /*gradients=*/nullptr);
    bool all_max = true;
    for (auto& pair : parameters) {
      if (pair.second->value < pair.second->max) {
        all_max = false;
        break;
      }
    }
    // Stop when the output latency meets the per-CPU processing-time target,
    // no parameter can be incremented further, or the worst-case buffered
    // bytes exceed the RAM budget.
    if (output_time < processing_time / optimization_params.cpu_budget() ||
        all_max ||
        TotalMaximumBufferedBytes(snapshot) >
            optimization_params.ram_budget()) {
      break;
    }
    // Greedy step: trial-increment each parameter in isolation, remember the
    // one with the best latency improvement (`delta`), then undo the trial.
    double best_delta = -1.0L;
    Parameter* best_parameter = nullptr;
    for (auto& pair : parameters) {
      if (pair.second->value >= pair.second->max) {
        continue;
      }
      pair.second->value++;
      double new_output_time =
          OutputTime(snapshot, optimization_params.model_input_time(),
                     /*gradients=*/nullptr);
      double delta = output_time - new_output_time;
      // Buffer-size parameters must clear `kBufferSizeMinDelta` to win.
      if (delta > best_delta &&
          (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
        best_delta = delta;
        best_parameter = pair.second.get();
      }
      pair.second->value--;
    }
    if (!best_parameter) {
      VLOG(2) << "Failed to find a tunable parameter that would further "
                 "decrease the output time. This means that the autotuning "
                 "optimization got stuck in a local maximum. The optimization "
                 "attempt will terminate early.";
      break;
    }
    best_parameter->value++;
  }
  UpdateStateValues(&parameters);
}
1906
OutputTime(std::shared_ptr<Node> node,double model_input_time,Model::ParameterGradients * gradients)1907 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
1908 Model::ParameterGradients* gradients) {
1909 // To store the input time for each node.
1910 Model::NodeValues input_times = {{kModelInputTimeKey, model_input_time}};
1911
1912 // TODO(jsimsa): Now that we are accounting for buffer size in wait time
1913 // computation, assuming that the input is infinitely fast will result in
1914 // inaccurate estimates of the output latency.
1915 //
1916 // We should compute the output latency as a fix-point of the following
1917 // equation: `output_time = node(OutputTime(input_times(1, output_time))`.
1918
1919 return node->OutputTime(&input_times, gradients);
1920 }
1921
// Returns the number of bytes currently buffered in the tree rooted at
// `node`, delegating to the node's own traversal.
double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
  return node->TotalBufferedBytes();
}
1925
// Returns the worst-case number of buffered bytes for the tree rooted at
// `node`, delegating to the node's own traversal.
double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
  return node->TotalMaximumBufferedBytes();
}
1929
// Returns the total processing time of the tree rooted at `node`; the
// per-node breakdown is not needed here, hence the null `processing_times`.
double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
  return node->TotalProcessingTime(/*processing_times=*/nullptr);
}
1933
ToProto(ModelProto * model_proto)1934 Status Model::ToProto(ModelProto* model_proto) {
1935 tf_shared_lock l(mu_);
1936 model_proto->set_id_counter(id_counter_);
1937 model_proto->set_collect_resource_usage(collect_resource_usage_);
1938 return ModelToProtoHelper(output_, model_proto);
1939 }
1940
FromProto(ModelProto model_proto,std::unique_ptr<Model> * model)1941 Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
1942 std::unique_ptr<Model> restored_model = std::make_unique<Model>();
1943 mutex_lock l(restored_model->mu_);
1944 TF_RETURN_IF_ERROR(
1945 ModelFromProtoHelper(model_proto, &restored_model->output_));
1946 restored_model->id_counter_ = model_proto.id_counter();
1947 restored_model->collect_resource_usage_.store(
1948 model_proto.collect_resource_usage());
1949 *model = std::move(restored_model);
1950 return Status::OK();
1951 }
1952
Save(const string & fname,std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params)1953 Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
1954 const OptimizationParams& optimization_params) {
1955 ModelProto model_proto;
1956 std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
1957 {
1958 mutex_lock l(model_snapshot->mu_);
1959 model_snapshot->output_ = std::move(snapshot);
1960 model_snapshot->id_counter_ = id_counter_;
1961 model_snapshot->collect_resource_usage_.store(collect_resource_usage_);
1962 }
1963 TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
1964 OptimizationParams* saved_optimization_params =
1965 model_proto.mutable_optimization_params();
1966 *saved_optimization_params = optimization_params;
1967 return WriteBinaryProto(Env::Default(), fname, model_proto);
1968 }
1969
Load(const string & fname,std::unique_ptr<Model> * model,OptimizationParams * optimization_params)1970 Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
1971 OptimizationParams* optimization_params) {
1972 ModelProto model_proto;
1973 TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fname, &model_proto));
1974 TF_RETURN_IF_ERROR(FromProto(model_proto, model));
1975 const OptimizationParams restored_optimization_params =
1976 model_proto.optimization_params();
1977 *optimization_params = restored_optimization_params;
1978 return Status::OK();
1979 }
1980
// Renders the model as the debug string of its `ModelProto`, caching the
// result for `kMinSecondsBetweenCalls` seconds to bound the cost of frequent
// gauge sampling.
std::string Model::DebugString() {
  constexpr int64_t kMinSecondsBetweenCalls = 30;
  // NOTE(review): `cache_until_` and `cached_debug_string_` are accessed
  // without holding `mu_` here; presumably concurrent callers are tolerated
  // or serialized elsewhere — confirm thread-safety of the cache fields.
  if (absl::Now() < cache_until_) return cached_debug_string_;
  std::shared_ptr<Node> snapshot;
  {
    tf_shared_lock l(mu_);
    // Before any node is added there is nothing to render; return whatever is
    // cached (initially empty).
    if (!output_) return cached_debug_string_;
    snapshot = output_->Snapshot();
  }
  // TODO(jsimsa): Populate OptimizationParams.
  ModelProto model_proto;
  Status s = ModelToProtoHelper(snapshot, &model_proto);
  if (s.ok()) {
    cached_debug_string_ = model_proto.DebugString();
  } else {
    // On failure, keep (and return) the stale cached string.
    LOG(WARNING) << s.error_message();
  }
  cache_until_ = absl::Now() + absl::Seconds(kMinSecondsBetweenCalls);
  return cached_debug_string_;
}
2001
2002 } // namespace model
2003 } // namespace data
2004 } // namespace tensorflow
2005