1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/model.h"
17
18 #include <memory>
19
20 #include "absl/time/clock.h"
21 #include "tensorflow/core/framework/cancellation.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace model {
28
// Out-of-line definitions for the static constexpr members declared in
// Model; required for ODR-use before C++17 made constexpr statics implicitly
// inline.
constexpr int64 Model::kOptimizationPeriodMinMs;
constexpr int64 Model::kOptimizationPeriodMaxMs;
31
32 namespace {
33
34 // Helper function for node traversal that doesn't skip any nodes.
IsAnyNode(const std::shared_ptr<Node> node)35 inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
36
37 // Helper function for node traversal that filters out nodes for which
38 // autotuning is disabled.
IsAutotuneNode(const std::shared_ptr<Node> node)39 inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
40 return node->autotune();
41 }
42
// Returns the square of `value`. Plain multiplication is used instead of
// std::pow to keep call sites terse and cheap.
inline double Square(double value) {
  return value * value;
}
45
46 // Collects "essential" parallelism parameters and buffer size parameters in the
47 // tree rooted in the given node. Which parallelism parameters are essential is
48 // determined by the relative processing time spent in the corresponding
49 // transformation. The collected parameters are returned via maps that map node
50 // names to their respective parameters.
CollectParameters(std::shared_ptr<Node> node,const absl::flat_hash_map<string,std::shared_ptr<Parameter>> & parameters,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parallelism_parameters,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * buffer_size_parameters)51 inline void CollectParameters(
52 std::shared_ptr<Node> node,
53 const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
54 absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
55 parallelism_parameters,
56 absl::flat_hash_map<string, std::shared_ptr<Parameter>>*
57 buffer_size_parameters) {
58 // Parallelism parameter is considered to be essential if the corresponding
59 // transformations's processing time is greater than essential rate times the
60 // average transformation self processing time.
61 constexpr double kEssentialRate = 0.3L;
62
63 absl::flat_hash_map<string, double> processing_times;
64 double processing_time = node->TotalProcessingTime(&processing_times);
65 double uniform_share =
66 processing_time / static_cast<double>(processing_times.size());
67 for (auto& pair : parameters) {
68 if (pair.second->name == kParallelism &&
69 processing_times[pair.first] > kEssentialRate * uniform_share) {
70 parallelism_parameters->insert(pair);
71 } else if (pair.second->name == kBufferSize) {
72 buffer_size_parameters->insert(pair);
73 }
74 }
75 }
76
77 // Applies the gradient descent method once and updates the parameter values. If
78 // the new value is out of the range, bound it within the range between the
79 // minimal and maximum values.
UpdateParameterValues(const absl::flat_hash_map<string,double> & gradients,absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters)80 inline void UpdateParameterValues(
81 const absl::flat_hash_map<string, double>& gradients,
82 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
83 // Gradient descent step size.
84 constexpr double kDescentStep = 0.1L;
85 double new_value;
86
87 double max_abs_derivative = 1.0;
88 for (auto& pair : *parameters) {
89 if (std::round(pair.second->value) != pair.second->max) {
90 auto* gradient = gtl::FindOrNull(gradients, pair.first);
91 if (gradient) {
92 max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
93 }
94 }
95 }
96 for (auto& pair : *parameters) {
97 auto* gradient = gtl::FindOrNull(gradients, pair.first);
98 if (gradient) {
99 new_value =
100 pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
101 // Projection on a feasible interval.
102 if (new_value > pair.second->max) {
103 pair.second->value = pair.second->max;
104 } else if (new_value < pair.second->min) {
105 pair.second->value = pair.second->min;
106 } else {
107 pair.second->value = new_value;
108 }
109 }
110 }
111 }
112
113 // Copies the parameter values (which are for optimization tuning) and updates
114 // the state values (which are for the input pipeline to follow).
UpdateStateValues(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters)115 inline void UpdateStateValues(
116 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) {
117 VLOG(2) << "Number of tunable parameters: " << parameters->size();
118 for (auto& pair : *parameters) {
119 auto& parameter = pair.second;
120 VLOG(2) << "Setting tunable parameter " << pair.first << " to "
121 << parameter->value;
122 mutex_lock l(*parameter->state->mu);
123 parameter->state->value = parameter->value;
124 parameter->state->cond_var->notify_all();
125 }
126 }
127
// The first input of InterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class InterleaveMany : public Node {
 public:
  using Node::Node;

  virtual ~InterleaveMany() {}

 protected:
  // Creates a copy of this node (without its inputs), attached to `output`.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<InterleaveMany>(
        Args{id_, name_, std::move(output)});
  }

  // Records this node's input time in `input_times` (keyed by long name).
  void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
      const override TF_SHARED_LOCKS_REQUIRED(mu_) {
    // Input time inherited from the output node, or the model-level input
    // time when this node is the root.
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // With at most one input there is no interleave "cycle"; pass the
    // inherited input time through unchanged.
    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for InterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except first input) to return an
    // element. Regardless of the `block_length` parameter of InterleaveMany
    // node, the average input time for any of the `(num_inputs() - 1)` input
    // nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the average
  // output time of inputs comprising the interleave "cycle".
  void OutputTimeLocked(
      const absl::flat_hash_map<string, double>& input_times,
      absl::flat_hash_map<string, double>* gradients,
      absl::flat_hash_map<string, double>* output_times,
      absl::flat_hash_map<string, double>* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    // Degenerate case: no interleave cycle, so the output time is just the
    // self processing time and autotunable descendants contribute no
    // gradients.
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& node :
             CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
          gradients->erase(node->long_name());
        }
      }
      return;
    }

    // Average output time across the cycle inputs; the first input is
    // excluded (see class comment).
    double inputs_output_time =
        (OutputTimeForInputs(*output_times) -
         (*output_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    if (gradients) {
      // Chain rule: descendants' derivatives are scaled by the same averaging
      // factor applied to their output times.
      for (const auto& node :
           CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
        auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
        if (gradient) {
          *gradient /= static_cast<double>(num_inputs() - 1);
        }
      }

      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients) -
          (*output_time_gradients)[inputs_.front()->long_name()];

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      absl::flat_hash_map<string, std::shared_ptr<Parameter>>
          first_input_parameters;
      inputs_.front()->CollectTunableParameters(&first_input_parameters);
      for (auto& pair : first_input_parameters) {
        (*gradients)[pair.first] = 0.0L;
      }
    }
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    // Degenerate case: no interleave cycle.
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // Average processing time across the cycle inputs; the first input is
    // excluded (see class comment).
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Serializes this node into `node_proto`, recording its node class on top of
  // the base Node fields.
  // NOTE(review): unlike the other virtual methods here, this is not marked
  // `override` — confirm Node::ToProto is virtual and add the keyword if so.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
    return Status::OK();
  }
};
248
// The first input of AsyncInterleaveMany corresponds to the input dataset
// whose elements are used to create the (derived) input datasets whose
// elements are interleaved as output.
//
// TODO(jsimsa): model the first input
class AsyncInterleaveMany : public Node {
 public:
  // Stores the given tunable parameters (e.g. parallelism), keyed by
  // parameter name.
  AsyncInterleaveMany(Node::Args args,
                      std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncInterleaveMany() {}

 protected:
  // Creates a copy of this node (with its parameters but without its inputs),
  // attached to `output`.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncInterleaveMany>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  // Records this node's input time in `input_times` (keyed by long name).
  void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
      const override TF_SHARED_LOCKS_REQUIRED(mu_) {
    // Input time inherited from the output node, or the model-level input
    // time when this node is the root.
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // With at most one input there is no interleave "cycle"; pass the
    // inherited input time through unchanged.
    if (num_inputs() <= 1) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
    // input time for AsyncInterleaveMany node to call one of the
    // `(num_inputs() - 1)` input nodes (except first input) to return an
    // element. Regardless of the `block_length` parameter of
    // AsyncInterleaveMany node, the average input time for any of the
    // `(num_inputs() - 1)` input nodes to be called is computed as:
    double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
                        static_cast<double>(num_inputs() - 1);
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of self processing time and expected wait time
  // from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the average output time of inputs comprising the
  // interleave "cycle" divided by `parallelism`, `consumer_time` is the
  // `input_time` specified through `input_times` divided by `num_inputs() - 1`,
  // and if the node has parallelism parameter, then `buffer_size` is derived
  // from `parallelism`.
  void OutputTimeLocked(
      const absl::flat_hash_map<string, double>& input_times,
      absl::flat_hash_map<string, double>* gradients,
      absl::flat_hash_map<string, double>* output_times,
      absl::flat_hash_map<string, double>* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    // Degenerate case: no interleave cycle, so the output time is just the
    // self processing time and autotunable descendants contribute no
    // gradients.
    if (num_inputs() <= 1) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& node :
             CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
          gradients->erase(node->long_name());
        }
      }
      return;
    }

    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());
    consumer_time = input_time / static_cast<double>(num_inputs() - 1);
    // Effective parallelism is capped by the cycle length.
    double parallelism = num_inputs() - 1;  // default to cycle length
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      parallelism = std::min(parallelism, (*parameter)->value);
    }
    // Average producer time per cycle input, adjusted for parallelism; the
    // first input is excluded (see class comment).
    double output_time_for_inputs =
        OutputTimeForInputs(*output_times) -
        (*output_times)[inputs_.front()->long_name()];
    producer_time = output_time_for_inputs /
                    static_cast<double>(num_inputs() - 1) / parallelism;

    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      // Chain rule: the output time depends on the input time both directly
      // (consumer side) and through the inputs' output times (producer side).
      (*output_time_gradients)[long_name()] =
          consumer_time_der +
          producer_time_der * inputs_time_der_sum / parallelism;

      // Scale descendants' derivatives by the producer-time factor.
      for (const auto& node :
           CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
        auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
        if (gradient) {
          *gradient *= (producer_time_der /
                        static_cast<double>(num_inputs() - 1) / parallelism);
        }
      }

      // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
      // first input equal to 0 since its output time is excluded from
      // computations.
      absl::flat_hash_map<string, std::shared_ptr<Parameter>>
          first_input_parameters;
      inputs_.front()->CollectTunableParameters(&first_input_parameters);
      for (auto& pair : first_input_parameters) {
        (*gradients)[pair.first] = 0.0L;
      }
      // Add derivative w.r.t. own parallelism parameter.
      if (parameter && (*parameter)->state->tunable) {
        (*gradients)[long_name()] =
            buffer_size_der - producer_time_der * producer_time / parallelism;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the average
  // processing time of inputs comprising the interleave "cycle".
  void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    // Degenerate case: no interleave cycle.
    if (num_inputs() <= 1) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    // Average processing time across the cycle inputs; the first input is
    // excluded (see class comment).
    double inputs_processing_time =
        (TotalProcessingTimeForInputs(*total_processing_times) -
         (*total_processing_times)[inputs_.front()->long_name()]) /
        static_cast<double>(num_inputs() - 1);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Upper bound on buffered bytes: the parallelism value times the average
  // size of a buffered element (zero when no parallelism parameter is set).
  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parameter) {
      result += (*parameter)->value * AverageBufferedElementSize();
    }
    return result;
  }

  // Serializes this node into `node_proto`, recording its node class on top of
  // the base Node fields.
  // NOTE(review): not marked `override` — confirm Node::ToProto is virtual and
  // add the keyword if so.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
    return Status::OK();
  }
};
424
// Models a transformation with a statically known ratio (`ratio_`) relating
// the number of input elements consumed to the number of output elements
// produced.
class KnownRatio : public Node {
 public:
  KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}

  virtual ~KnownRatio() {}

 protected:
  // Creates a copy of this node (without its inputs), attached to `output`.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
                                        ratio_);
  }

  // The input time is the sum of inherited input time and self processing time,
  // divided by `ratio_`.
  void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
      const override TF_SHARED_LOCKS_REQUIRED(mu_) {
    // Input time inherited from the output node, or the model-level input
    // time when this node is the root.
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }

    // A zero ratio means inputs are not consumed; pass the inherited input
    // time through unchanged (and avoid dividing by zero below).
    if (ratio_ == 0) {
      (*input_times)[long_name()] = inherited_input_time;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of the self processing time and the product of
  // `ratio_` and the sum of output times of inputs.
  void OutputTimeLocked(
      const absl::flat_hash_map<string, double>& input_times,
      absl::flat_hash_map<string, double>* gradients,
      absl::flat_hash_map<string, double>* output_times,
      absl::flat_hash_map<string, double>* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    // Degenerate case: inputs do not contribute, and autotunable descendants
    // contribute no gradients.
    if (ratio_ == 0) {
      (*output_times)[long_name()] = self_processing_time;
      if (gradients) {
        for (const auto& node :
             CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
          gradients->erase(node->long_name());
        }
      }
      return;
    }
    if (gradients) {
      // Chain rule: descendants' derivatives are scaled by `ratio_`, matching
      // the scaling of their output times below.
      for (const auto& node :
           CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
        auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
        if (gradient) {
          *gradient *= ratio_;
        }
      }
      (*output_time_gradients)[long_name()] =
          OutputTimeGradientsForInputs(*output_time_gradients);
    }
    double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
    (*output_times)[long_name()] = self_processing_time + inputs_output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    // Degenerate case: inputs do not contribute.
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Serializes this node into `node_proto`, recording its node class and ratio
  // on top of the base Node fields.
  // NOTE(review): not marked `override` — confirm Node::ToProto is virtual and
  // add the keyword if so.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    return Status::OK();
  }

 private:
  // Number of input elements consumed per output element.
  const double ratio_;
};
522
// Models an asynchronous (buffered) transformation with a statically known
// ratio (`ratio_`) of input elements consumed per output element; see the
// member comments at the bottom for the meaning of `ratio_` and
// `memory_ratio_`.
class AsyncKnownRatio : public Node {
 public:
  // Stores the given tunable parameters (parallelism and/or buffer size),
  // keyed by parameter name.
  AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
                  std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncKnownRatio() {}

 protected:
  // Creates a copy of this node (with its parameters but without its inputs),
  // attached to `output`.
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncKnownRatio>(
        Args{id_, name_, std::move(output)}, ratio_, memory_ratio_, parameters);
  }

  // The input time is the sum of inherited input time and parallelism adjusted
  // self processing time, divided by `ratio_`.
  void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
      const override TF_SHARED_LOCKS_REQUIRED(mu_) {
    // Input time inherited from the output node, or the model-level input
    // time when this node is the root.
    double inherited_input_time;
    if (output_) {
      inherited_input_time = (*input_times)[output_->long_name()];
    } else {
      inherited_input_time = (*input_times)[kModelInputTimeKey];
    }
    double parallelism = 1.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
    }

    // A zero ratio means inputs are not consumed; skip the division by
    // `ratio_`.
    if (ratio_ == 0.0) {
      (*input_times)[long_name()] =
          inherited_input_time + SelfProcessingTimeLocked() / parallelism;
      return;
    }
    double input_time =
        (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
        ratio_;
    (*input_times)[long_name()] = input_time;
  }

  // The output time is the sum of parallelism adjusted self processing time and
  // expected wait time from the buffer model estimated using
  // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
  // `producer_time` is the product of `ratio_` and the sum of output times of
  // inputs, `consumer_time` is the product of `ratio_` and the `input_time`
  // specified through `input_times` (since for each element stored in the
  // buffer, the inputs need to be called `ratio_` times), and if the node has
  // parallelism parameter, then `buffer_size` is derived from `parallelism`.
  //
  // Current implementation assumes that there is at most 1 parameter per node.
  void OutputTimeLocked(
      const absl::flat_hash_map<string, double>& input_times,
      absl::flat_hash_map<string, double>* gradients,
      absl::flat_hash_map<string, double>* output_times,
      absl::flat_hash_map<string, double>* output_time_gradients) const override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double parallelism = 1.0;
    double buffer_size = 0.0;
    auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
    auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (parallelism_parameter) {
      parallelism = (*parallelism_parameter)->value;
      if (ratio_ == 0) {
        buffer_size = parallelism;
      } else {
        // Currently, MapAndBatch is the only transformation creates
        // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
        // `parallelism` threads to apply the function on elements from input
        // dataset, while one element in the buffer actually corresponds to
        // `ratio_` elements from input dataset. So we adjust the `buffer_size`
        // by dividing `ratio_`.
        buffer_size = parallelism / ratio_;
      }
    } else if (buffer_size_parameter) {
      buffer_size = (*buffer_size_parameter)->value;
    }
    double self_processing_time = SelfProcessingTimeLocked();
    double output_time, wait_time, consumer_time, producer_time;
    double input_time = input_times.at(long_name());

    // Source-like case (`ratio_ == 0`): inputs do not contribute, so only the
    // consumer side of the buffer model produces wait time.
    if (ratio_ == 0) {
      consumer_time = input_time;
      producer_time = 0.0L;
      if (gradients) {
        // Autotunable descendants contribute no gradients in this case.
        for (const auto& node :
             CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
          gradients->erase(node->long_name());
        }

        double producer_time_der = 0.0L;
        double consumer_time_der = 0.0L;
        double buffer_size_der = 0.0L;
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    &producer_time_der, &consumer_time_der,
                                    &buffer_size_der);
        (*output_time_gradients)[long_name()] = consumer_time_der;
        // Derivative w.r.t. this node's own tunable parameter (at most one of
        // parallelism/buffer size is expected; see method comment).
        if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
          (*gradients)[long_name()] = -(1.0L + consumer_time_der) *
                                          self_processing_time /
                                          Square(parallelism) +
                                      buffer_size_der;
        } else if (buffer_size_parameter &&
                   (*buffer_size_parameter)->state->tunable) {
          (*gradients)[long_name()] = buffer_size_der;
        }
      } else {
        wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                    /*producer_time_derivative=*/nullptr,
                                    /*consumer_time_derivative=*/nullptr,
                                    /*buffer_size_derivative=*/nullptr);
      }
      output_time = self_processing_time / parallelism + wait_time;
      (*output_times)[long_name()] = output_time;
      return;
    }

    consumer_time = input_time * ratio_;
    producer_time = ratio_ * OutputTimeForInputs(*output_times);
    if (gradients) {
      double producer_time_der = 0.0L;
      double consumer_time_der = 0.0L;
      double buffer_size_der = 0.0L;
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  &producer_time_der, &consumer_time_der,
                                  &buffer_size_der);
      double inputs_time_der_sum =
          OutputTimeGradientsForInputs(*output_time_gradients);
      // Chain rule: the output time depends on the input time both directly
      // (consumer side) and through the inputs' output times (producer side).
      (*output_time_gradients)[long_name()] =
          consumer_time_der + producer_time_der * inputs_time_der_sum;

      // Scale descendants' derivatives by `ratio_` times the producer-time
      // derivative, matching the scaling of their output times above.
      for (const auto& node :
           CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
        auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
        if (gradient) {
          *gradient *= (ratio_ * producer_time_der);
        }
      }

      // Add derivative w.r.t. own parameter if it's tunable.
      if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
        (*gradients)[long_name()] = buffer_size_der / ratio_ -
                                    (1.0L + consumer_time_der +
                                     producer_time_der * inputs_time_der_sum) *
                                        self_processing_time /
                                        Square(parallelism);
      } else if (buffer_size_parameter &&
                 (*buffer_size_parameter)->state->tunable) {
        (*gradients)[long_name()] = buffer_size_der;
      }
    } else {
      wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
                                  /*producer_time_derivative=*/nullptr,
                                  /*consumer_time_derivative=*/nullptr,
                                  /*buffer_size_derivative=*/nullptr);
    }
    output_time = self_processing_time / parallelism + wait_time;
    (*output_times)[long_name()] = output_time;
  }

  // The processing time is the sum of the self processing time and the product
  // of `ratio_` and the sum of processing times of inputs.
  void TotalProcessingTimeLocked(
      absl::flat_hash_map<string, double>* processing_times,
      absl::flat_hash_map<string, double>* total_processing_times) override
      TF_SHARED_LOCKS_REQUIRED(mu_) {
    double self_processing_time = SelfProcessingTimeLocked();
    if (processing_times) {
      (*processing_times)[long_name()] = self_processing_time;
    }
    // Degenerate case: inputs do not contribute.
    if (ratio_ == 0) {
      (*total_processing_times)[long_name()] = self_processing_time;
      return;
    }
    double inputs_processing_time =
        ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
    (*total_processing_times)[long_name()] =
        self_processing_time + inputs_processing_time;
  }

  // Upper bound on buffered bytes: the buffer size (or, failing that, the
  // parallelism) parameter value times the average buffered element size,
  // scaled by `memory_ratio_` when it is non-zero.
  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
    double result = 0;
    auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
    if (!parameter) {
      parameter = gtl::FindOrNull(parameters_, kParallelism);
    }

    if (parameter) {
      if (memory_ratio_ == 0) {
        result += (*parameter)->value * AverageBufferedElementSize();
      } else {
        // The estimation is currently not accurate for MapAndBatchDataset for
        // the maximum buffer size does not match `num_parallel_calls`
        // parameter.
        result +=
            (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
      }
    }
    return result;
  }

  // Serializes this node into `node_proto`, recording its node class, ratio
  // and memory ratio on top of the base Node fields.
  // NOTE(review): not marked `override` — confirm Node::ToProto is virtual and
  // add the keyword if so.
  Status ToProto(ModelProto::Node* node_proto) const {
    TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
    node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
    node_proto->set_ratio(ratio_);
    node_proto->set_memory_ratio(memory_ratio_);
    return Status::OK();
  }

 private:
  // Identifies how many input elements need to be created to construct an
  // element for the dataset.
  //
  // Currently the value is 1 for PrefetchDataset and ParallelMapDataset,
  // batch_size for MapAndBatchDataset and ParallelBatchDataset.
  const double ratio_;
  // For parallelism nodes, identifies how many parallelism calls are introduced
  // by one buffered element. The value is defined to correctly estimate RAM
  // budget bound with given num_parallel_calls (or buffer_size) combined with
  // the estimated average size of buffered elements.
  const double memory_ratio_;
};
754
755 class UnknownRatio : public Node {
756 public:
757 using Node::Node;
758
~UnknownRatio()759 virtual ~UnknownRatio() {}
760
761 protected:
Clone(std::shared_ptr<Node> output) const762 std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
763 TF_SHARED_LOCKS_REQUIRED(mu_) {
764 return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
765 }
766
767 // The input time is the sum of inherited input time and self processing time,
768 // divided by the ratio estimate.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const769 void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
770 const override TF_SHARED_LOCKS_REQUIRED(mu_) {
771 double inherited_input_time;
772 if (output_) {
773 inherited_input_time = (*input_times)[output_->long_name()];
774 } else {
775 inherited_input_time = (*input_times)[kModelInputTimeKey];
776 }
777
778 if (num_elements_ == 0 || inputs_.empty() ||
779 inputs_.front()->num_elements() == 0) {
780 (*input_times)[long_name()] = inherited_input_time;
781 return;
782 }
783 std::shared_ptr<Node> input = inputs_.front();
784 double ratio = static_cast<double>(input->num_elements()) /
785 static_cast<double>(num_elements_);
786 double input_time =
787 (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
788 (*input_times)[long_name()] = input_time;
789 }
790
791 // The output time is the sum of the self processing time and the product of
792 // the ratio estimate and the sum of output times of inputs.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const793 void OutputTimeLocked(
794 const absl::flat_hash_map<string, double>& input_times,
795 absl::flat_hash_map<string, double>* gradients,
796 absl::flat_hash_map<string, double>* output_times,
797 absl::flat_hash_map<string, double>* output_time_gradients) const override
798 TF_SHARED_LOCKS_REQUIRED(mu_) {
799 double self_processing_time = SelfProcessingTimeLocked();
800 if (num_elements_ == 0 || inputs_.empty() ||
801 inputs_.front()->num_elements() == 0) {
802 (*output_times)[long_name()] = self_processing_time;
803 if (gradients) {
804 for (const auto& node :
805 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
806 gradients->erase(node->long_name());
807 }
808 }
809 return;
810 }
811 // TODO(jsimsa): The current implementation assumes that the number of input
812 // elements consumed per output is the same across all inputs.
813 double ratio = static_cast<double>(inputs_.front()->num_elements()) /
814 static_cast<double>(num_elements_);
815 if (gradients) {
816 for (const auto& node :
817 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
818 auto* gradient = gtl::FindOrNull(*gradients, node->long_name());
819 if (gradient) {
820 *gradient *= ratio;
821 }
822 }
823 (*output_time_gradients)[long_name()] =
824 OutputTimeGradientsForInputs(*output_time_gradients);
825 }
826 double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
827 (*output_times)[long_name()] = self_processing_time + inputs_output_time;
828 }
829
830 // The processing time is the sum of the self processing time and the product
831 // of the ratio estimate and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)832 void TotalProcessingTimeLocked(
833 absl::flat_hash_map<string, double>* processing_times,
834 absl::flat_hash_map<string, double>* total_processing_times) override
835 TF_SHARED_LOCKS_REQUIRED(mu_) {
836 double self_processing_time = SelfProcessingTimeLocked();
837 if (processing_times) {
838 (*processing_times)[long_name()] = self_processing_time;
839 }
840 if (inputs_.empty() || num_elements_ == 0) {
841 (*total_processing_times)[long_name()] = self_processing_time;
842 return;
843 }
844 // TODO(jsimsa): The current implementation assumes that the number of input
845 // elements consumed per output is the same across all inputs.
846 std::shared_ptr<Node> input = inputs_.front();
847 double ratio = static_cast<double>(input->num_elements()) /
848 static_cast<double>(num_elements_);
849 double inputs_processing_time =
850 ratio * TotalProcessingTimeForInputs(*total_processing_times);
851 (*total_processing_times)[long_name()] =
852 self_processing_time + inputs_processing_time;
853 }
854
ToProto(ModelProto::Node * node_proto) const855 Status ToProto(ModelProto::Node* node_proto) const {
856 TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
857 node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
858 return Status::OK();
859 }
860 };
861
862 class Unknown : public Node {
863 public:
864 using Node::Node;
865
~Unknown()866 virtual ~Unknown() {}
867
868 protected:
Clone(std::shared_ptr<Node> output) const869 std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
870 TF_SHARED_LOCKS_REQUIRED(mu_) {
871 return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
872 }
873
874 // The input time is the inherited input time.
InputTimeLocked(absl::flat_hash_map<string,double> * input_times) const875 void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
876 const override TF_SHARED_LOCKS_REQUIRED(mu_) {
877 double inherited_input_time;
878 if (output_) {
879 inherited_input_time = (*input_times)[output_->long_name()];
880 } else {
881 inherited_input_time = (*input_times)[kModelInputTimeKey];
882 }
883 (*input_times)[long_name()] = inherited_input_time;
884 }
885
886 // The output time is the sum of output times of inputs.
OutputTimeLocked(const absl::flat_hash_map<string,double> & input_times,absl::flat_hash_map<string,double> * gradients,absl::flat_hash_map<string,double> * output_times,absl::flat_hash_map<string,double> * output_time_gradients) const887 void OutputTimeLocked(
888 const absl::flat_hash_map<string, double>& input_times,
889 absl::flat_hash_map<string, double>* gradients,
890 absl::flat_hash_map<string, double>* output_times,
891 absl::flat_hash_map<string, double>* output_time_gradients) const override
892 TF_SHARED_LOCKS_REQUIRED(mu_) {
893 (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
894 if (gradients) {
895 (*output_time_gradients)[long_name()] =
896 OutputTimeGradientsForInputs(*output_time_gradients);
897 }
898 }
899
900 // The processing time is the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)901 void TotalProcessingTimeLocked(
902 absl::flat_hash_map<string, double>* processing_times,
903 absl::flat_hash_map<string, double>* total_processing_times) override
904 TF_SHARED_LOCKS_REQUIRED(mu_) {
905 if (processing_times) {
906 (*processing_times)[long_name()] = SelfProcessingTimeLocked();
907 }
908 (*total_processing_times)[long_name()] =
909 TotalProcessingTimeForInputs(*total_processing_times);
910 }
911
ToProto(ModelProto::Node * node_proto) const912 Status ToProto(ModelProto::Node* node_proto) const {
913 TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
914 node_proto->set_node_class(NodeClass::UNKNOWN);
915 return Status::OK();
916 }
917 };
918
919 } // namespace
920
// Out-of-line definition of the per-thread work-start timestamp declared in
// `Node`.
thread_local int64 Node::work_start_;
922
MakeParameter(const string & name,std::shared_ptr<SharedState> state,double min,double max)923 std::shared_ptr<Parameter> MakeParameter(const string& name,
924 std::shared_ptr<SharedState> state,
925 double min, double max) {
926 return std::make_shared<Parameter>(name, state, min, max);
927 }
928
// Creates an `InterleaveMany` node from the given arguments.
std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args) {
  return std::make_shared<InterleaveMany>(std::move(args));
}
932
// Creates an `AsyncInterleaveMany` node from the given arguments and tunable
// parameters.
std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncInterleaveMany>(std::move(args),
                                               std::move(parameters));
}
938
// Creates a `KnownRatio` node with the given input-to-output element ratio.
std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
  return std::make_shared<KnownRatio>(std::move(args), ratio);
}
942
// Creates an `AsyncKnownRatio` node with separate element and memory ratios.
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio, double memory_ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
                                           std::move(parameters));
}
949
// Convenience overload that uses `ratio` for both the element ratio and the
// memory ratio.
std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
                                 /*memory_ratio=*/ratio, std::move(parameters));
}
956
// Creates a source node, modeled as a `KnownRatio` node with ratio 0 (its
// inputs, if any, contribute nothing to processing time).
std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
  return MakeKnownRatioNode(std::move(args), 0);
}
960
// Creates an `UnknownRatio` node from the given arguments.
std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
  return std::make_shared<UnknownRatio>(std::move(args));
}
964
// Creates an `Unknown` node from the given arguments.
std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
  return std::make_shared<Unknown>(std::move(args));
}
968
// Computes the expected time a consumer spends waiting on a producer through
// a buffer holding up to `buffer_size` elements, where the producer needs
// `producer_time` per element and the consumer `consumer_time` per element
// (same time unit). Optionally also computes the partial derivatives of the
// wait time with respect to each of the three quantities via the out
// pointers; each pointer may be null if the corresponding derivative is not
// needed.
double Node::ComputeWaitTime(const double& producer_time,
                             const double& consumer_time,
                             const double& buffer_size,
                             double* producer_time_derivative,
                             double* consumer_time_derivative,
                             double* buffer_size_derivative) {
  // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
  // p=`p_buffer_empty`, T=`wait_time`, then we have:
  // if y = 0, then p = 0;
  // elif x = 0, then p = 1;
  // elif x = y, then p = 1 / (n+1);
  // else p = [1 - x/y] / [1 - power(x/y, n+1)].
  //
  // We also have T = p * y, and derivatives of T w.r.t. x, y, n are computed:
  // dT/dx = dp/dx * y,
  // dT/dy = p + dp/dy * y,
  // dT/dn = dp/dn * y.
  // Then the remaining work is to compute dp/dx, dp/dy, dp/dn by considering
  // different cases and substitute the values into above formulas.

  // Case 1: if producer is infinitely fast. The buffer will always be full.
  // Wait time will always be 0.
  if (producer_time == 0) {
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = 0` since p=0 on the
      // line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy at
      // (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy, where
      // p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
      if (buffer_size == 0 || consumer_time == 0) {
        *producer_time_derivative = 1.0L;
      } else {
        *producer_time_derivative = 0.0L;
      }
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = 0.0L;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return 0.0L;
  }

  // Case 2: if consumer is infinitely fast. Wait time is always the time to
  // produce an output.
  if (consumer_time == 0) {
    if (producer_time_derivative) {
      *producer_time_derivative = 1.0L;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since p=1 on the
      // line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx at
      // (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx, where
      // p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
      if (buffer_size == 0) {
        *consumer_time_derivative = 0.0L;
      } else {
        *consumer_time_derivative = -1.0L;
      }
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = 0.0L;
    }
    return producer_time;
  }

  // Case 3: the consumer and the producer are equally fast. Expected wait time
  // decreases linearly with the size of the buffer.
  if (consumer_time == producer_time) {
    const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
    const double p_buffer_empty_der =
        -buffer_size / (2.0L * buffer_size + 2.0L);
    if (producer_time_derivative) {
      // Note a common error is `*producer_time_derivative = p_buffer_empty`
      // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there. Actually
      // to compute dp/dy at (y,y), we need to consider
      // lim_{dy->0} [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1),
      // p(y,y+dy) = [1 - y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
      *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
    }
    if (consumer_time_derivative) {
      // Note a common error is `*consumer_time_derivative = 0` since
      // p=1/(n+1) on the line x=y doesn't imply dp/dx = 0 there. Actually to
      // compute dp/dx at (x,x), we need to consider
      // lim_{dx->0} [p(x+dx,x)-p(x,x)] / dx, where p(x,x)=1/(n+1),
      // p(x+dx,x) = [1 - (x+dx)/x] / [1 - power((x+dx)/x, n+1)].
      *consumer_time_derivative = p_buffer_empty_der;
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
    }
    return p_buffer_empty * producer_time;
  }

  // Case 4: the consumer is slower than the producer and neither is infinitely
  // fast. Case 4 and Case 5 actually follow same formula. Separate them for
  // numerical computation reasons.
  if (consumer_time > producer_time) {
    const double ratio = producer_time / consumer_time;
    const double ratio_pow = std::pow(ratio, buffer_size);
    const double p_buffer_empty =
        ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
    const double p_buffer_empty_der =
        (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
        ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
    if (producer_time_derivative) {
      *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
    }
    if (consumer_time_derivative) {
      *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
    }
    if (buffer_size_derivative) {
      *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                                std::log(ratio) * producer_time;
    }
    return p_buffer_empty * producer_time;
  }

  // Case 5: the producer is slower than the consumer and neither is infinitely
  // fast.
  const double ratio = consumer_time / producer_time;
  const double ratio_pow = std::pow(ratio, buffer_size);
  const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
  const double p_buffer_empty_der =
      ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
      Square(1.0L - ratio_pow * ratio);
  if (producer_time_derivative) {
    *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
  }
  if (consumer_time_derivative) {
    *consumer_time_derivative = p_buffer_empty_der;
  }
  if (buffer_size_derivative) {
    *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
                              ratio_pow * ratio * std::log(ratio) *
                              producer_time;
  }
  return p_buffer_empty * producer_time;
}
1108
CollectTunableParameters(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters) const1109 void Node::CollectTunableParameters(
1110 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const {
1111 tf_shared_lock l(mu_);
1112 // Collect tunable parameters from the leaves of the nodes tree to the root.
1113 for (const auto& node :
1114 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1115 tf_shared_lock l(node->mu_);
1116 node->CollectTunableParametersHelper(parameters);
1117 }
1118 CollectTunableParametersHelper(parameters);
1119 }
1120
DebugString() const1121 string Node::DebugString() const {
1122 absl::flat_hash_map<string, string> debug_strings;
1123 tf_shared_lock l(mu_);
1124 // Build up the debug string from the leaves of the nodes tree to the root.
1125 for (const auto& node :
1126 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1127 tf_shared_lock l(node->mu_);
1128 node->DebugStringHelper(&debug_strings);
1129 }
1130 DebugStringHelper(&debug_strings);
1131
1132 return debug_strings[long_name()];
1133 }
1134
FlushMetrics()1135 void Node::FlushMetrics() {
1136 if (!record_metrics_) {
1137 return;
1138 }
1139 metrics_.record_bytes_consumed(bytes_consumed_);
1140 metrics_.record_bytes_produced(bytes_produced_);
1141 metrics_.record_num_elements(num_elements_);
1142 }
1143
OutputTime(absl::flat_hash_map<string,double> * input_times,absl::flat_hash_map<string,double> * gradients) const1144 double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
1145 absl::flat_hash_map<string, double>* gradients) const {
1146 // To store the output time gradient w.r.t. input time (if `gradients` is not
1147 // `nullptr`) and the output time for each node.
1148 absl::flat_hash_map<string, double> output_time_gradients, output_times;
1149 tf_shared_lock l(mu_);
1150 auto nodes = CollectNodes(TraversalOrder::BFS, IsAutotuneNode);
1151
1152 // Computes and stores input time for each node from the root to leaves of the
1153 // nodes tree.
1154 InputTimeLocked(input_times);
1155 for (const auto& node : nodes) {
1156 tf_shared_lock l(node->mu_);
1157 node->InputTimeLocked(input_times);
1158 }
1159
1160 std::reverse(nodes.begin(), nodes.end());
1161 // Computes and stores the output time and output time gradient w.r.t. input
1162 // time (if `gradients` is not `nullptr`) for each node from leaves of the
1163 // nodes tree to the root.
1164 for (const auto& node : nodes) {
1165 tf_shared_lock l(node->mu_);
1166 node->OutputTimeLocked(*input_times, gradients, &output_times,
1167 &output_time_gradients);
1168 }
1169 OutputTimeLocked(*input_times, gradients, &output_times,
1170 &output_time_gradients);
1171
1172 return output_times[long_name()];
1173 }
1174
Snapshot() const1175 std::shared_ptr<Node> Node::Snapshot() const {
1176 NodePairList node_pairs;
1177 auto result = SnapshotHelper(nullptr, &node_pairs);
1178
1179 while (!node_pairs.empty()) {
1180 auto node_pair = node_pairs.front();
1181 node_pairs.pop_front();
1182 std::shared_ptr<Node> current = node_pair.first,
1183 cloned_output = node_pair.second;
1184 cloned_output->add_input(
1185 current->SnapshotHelper(cloned_output, &node_pairs));
1186 }
1187 return result;
1188 }
1189
// Thread-safe wrapper around `SelfProcessingTimeLocked`.
double Node::SelfProcessingTime() const {
  tf_shared_lock l(mu_);
  return SelfProcessingTimeLocked();
}
1194
TotalBufferedBytes() const1195 double Node::TotalBufferedBytes() const {
1196 absl::flat_hash_map<string, double> total_bytes;
1197 tf_shared_lock l(mu_);
1198 // Compute total buffered bytes from the leaves of the nodes tree to the root.
1199 for (const auto& node :
1200 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1201 tf_shared_lock l(node->mu_);
1202 node->TotalBufferedBytesHelper(&total_bytes);
1203 }
1204 TotalBufferedBytesHelper(&total_bytes);
1205
1206 return total_bytes[long_name()];
1207 }
1208
TotalMaximumBufferedBytes() const1209 double Node::TotalMaximumBufferedBytes() const {
1210 absl::flat_hash_map<string, double> total_bytes;
1211 tf_shared_lock l(mu_);
1212 // Compute total maximum buffered bytes from the leaves of the nodes tree
1213 // to the root.
1214 for (const auto& node :
1215 CollectNodes(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1216 tf_shared_lock l(node->mu_);
1217 node->TotalMaximumBufferedBytesHelper(&total_bytes);
1218 }
1219 TotalMaximumBufferedBytesHelper(&total_bytes);
1220
1221 return total_bytes[long_name()];
1222 }
1223
TotalProcessingTime(absl::flat_hash_map<string,double> * processing_times)1224 double Node::TotalProcessingTime(
1225 absl::flat_hash_map<string, double>* processing_times) {
1226 // Create a hash map to store the per-element CPU time spent in the subtree
1227 // rooted in each node.
1228 absl::flat_hash_map<string, double> total_processing_times;
1229 tf_shared_lock l(mu_);
1230
1231 // Computes per-element CPU time spent in the subtree rooted in the node from
1232 // the leaves of the nodes tree to the root.
1233 for (const auto& node :
1234 CollectNodes(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1235 tf_shared_lock l(node->mu_);
1236 node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1237 }
1238 TotalProcessingTimeLocked(processing_times, &total_processing_times);
1239
1240 return total_processing_times[long_name()];
1241 }
1242
AverageBufferedElementSize() const1243 double Node::AverageBufferedElementSize() const {
1244 DCHECK_GE(num_elements_, 0);
1245 DCHECK_GE(buffered_elements_, 0);
1246 if (num_elements_ <= 0) {
1247 if (buffered_elements_ <= 0) {
1248 // If there are no produced elements or buffered elements recorded, return
1249 // 0.
1250 return 0;
1251 }
1252 // If there are no produced elements but some buffered elements, return the
1253 // average size of all buffered elements.
1254 return static_cast<double>(buffered_bytes_) /
1255 static_cast<double>(buffered_elements_);
1256 }
1257
1258 if (buffered_elements_ <= 0) {
1259 // If there are no buffered elements but some produced elements, return the
1260 // average size of all produced elements.
1261 return static_cast<double>(bytes_produced_) /
1262 static_cast<double>(num_elements_);
1263 }
1264
1265 // Otherwise, return the mean value of average size of all produced elements
1266 // and average size of all buffered elements.
1267 return (static_cast<double>(bytes_produced_) /
1268 static_cast<double>(num_elements_) +
1269 static_cast<double>(buffered_bytes_) /
1270 static_cast<double>(buffered_elements_)) /
1271 2.0;
1272 }
1273
OutputTimeForInputs(const absl::flat_hash_map<string,double> & output_times) const1274 double Node::OutputTimeForInputs(
1275 const absl::flat_hash_map<string, double>& output_times) const {
1276 double sum = 0;
1277 for (auto& input : inputs_) {
1278 // Inputs for which autotuning is disabled are excluded.
1279 if (input->autotune()) {
1280 sum += output_times.at(input->long_name());
1281 }
1282 }
1283 return sum;
1284 }
1285
OutputTimeGradientsForInputs(const absl::flat_hash_map<string,double> & output_time_gradients) const1286 double Node::OutputTimeGradientsForInputs(
1287 const absl::flat_hash_map<string, double>& output_time_gradients) const {
1288 double sum = 0;
1289 for (auto& input : inputs_) {
1290 // Inputs for which autotuning is disabled are excluded.
1291 if (input->autotune()) {
1292 sum +=
1293 gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1294 }
1295 }
1296 return sum;
1297 }
1298
TotalProcessingTimeForInputs(const absl::flat_hash_map<string,double> & total_processing_times)1299 double Node::TotalProcessingTimeForInputs(
1300 const absl::flat_hash_map<string, double>& total_processing_times) {
1301 // If the number of elements produced by an input is smaller than this
1302 // constant, then its processing time is estimated using a weighted average
1303 // of the empirical processing time and processing time history.
1304 constexpr int kNumElementsThreshold = 30;
1305
1306 // Identifies the minimum number of input processing times to collect
1307 // before the processing time history is used as a prior.
1308 constexpr int kCountThreshold = 30;
1309
1310 double sum = 0;
1311 for (auto& input : inputs_) {
1312 // Inputs for which autotuning is disabled are excluded.
1313 if (input->autotune()) {
1314 double input_processing_time =
1315 total_processing_times.at(input->long_name());
1316 int64 num_elements = input->num_elements();
1317 if (num_elements < kNumElementsThreshold) {
1318 if (input_processing_time_count_ < kCountThreshold) {
1319 sum += input_processing_time;
1320 } else {
1321 // The fewer elements the input has produced so far, the more weight
1322 // is assigned to the prior to reduce volatility.
1323 double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
1324 double prior =
1325 input_processing_time_sum_ / input_processing_time_count_;
1326 sum += (1.0L - prior_weight) * input_processing_time +
1327 prior_weight * prior;
1328 }
1329 } else {
1330 sum += input_processing_time;
1331 input_processing_time_count_++;
1332 input_processing_time_sum_ += input_processing_time;
1333 }
1334 }
1335 }
1336 return sum;
1337 }
1338
SelfProcessingTimeLocked() const1339 double Node::SelfProcessingTimeLocked() const {
1340 if (num_elements_ == 0) {
1341 return 0;
1342 }
1343 return static_cast<double>(processing_time_) /
1344 static_cast<double>(num_elements_);
1345 }
1346
CollectNodes(TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const1347 Node::NodeVector Node::CollectNodes(
1348 TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1349 TF_SHARED_LOCKS_REQUIRED(mu_) {
1350 NodeVector node_vector;
1351 std::list<std::shared_ptr<Node>> temp_list;
1352
1353 for (auto& input : inputs_) {
1354 if (collect_node(input)) {
1355 node_vector.push_back(input);
1356 temp_list.push_back(input);
1357 }
1358 }
1359
1360 while (!temp_list.empty()) {
1361 auto cur_node = temp_list.front();
1362 temp_list.pop_front();
1363 tf_shared_lock l(cur_node->mu_);
1364 for (auto& input : cur_node->inputs_) {
1365 if (collect_node(input)) {
1366 node_vector.push_back(input);
1367 temp_list.push_back(input);
1368 }
1369 }
1370 }
1371
1372 if (order == TraversalOrder::REVERSE_BFS) {
1373 std::reverse(node_vector.begin(), node_vector.end());
1374 }
1375 return node_vector;
1376 }
1377
CollectTunableParametersHelper(absl::flat_hash_map<string,std::shared_ptr<Parameter>> * parameters) const1378 void Node::CollectTunableParametersHelper(
1379 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const
1380 TF_SHARED_LOCKS_REQUIRED(mu_) {
1381 // If autotune is turned off or there are no elements recorded, we don't
1382 // collect the parameters on the node.
1383 if (!autotune_ || num_elements_ <= 0) {
1384 return;
1385 }
1386 for (auto& pair : parameters_) {
1387 if (pair.second->state->tunable) {
1388 parameters->insert(std::make_pair(long_name(), pair.second));
1389 }
1390 }
1391 }
1392
DebugStringHelper(absl::flat_hash_map<string,string> * debug_strings) const1393 void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
1394 const TF_SHARED_LOCKS_REQUIRED(mu_) {
1395 string result;
1396 strings::StrAppend(&result, long_name(), ":\n");
1397 strings::StrAppend(&result, " autotune=", autotune_.load(), "\n");
1398 strings::StrAppend(&result, " buffered_bytes=", buffered_bytes_.load(),
1399 "\n");
1400 strings::StrAppend(&result, " buffered_elements=", buffered_elements_.load(),
1401 "\n");
1402 strings::StrAppend(&result, " bytes_consumed=", bytes_consumed_.load(),
1403 "\n");
1404 strings::StrAppend(&result, " bytes_produced=", bytes_produced_.load(),
1405 "\n");
1406 strings::StrAppend(&result, " processing_time=", processing_time_.load(),
1407 "\n");
1408 strings::StrAppend(&result, " num_elements=", num_elements_.load(), "\n");
1409 string inputs;
1410 for (auto& input : inputs_) {
1411 strings::StrAppend(&inputs, input->long_name(), ",");
1412 }
1413 strings::StrAppend(&result, " inputs={", inputs, "}\n");
1414 for (auto& input : inputs_) {
1415 strings::StrAppend(&result, debug_strings->at(input->long_name()));
1416 }
1417 debug_strings->insert(std::make_pair(long_name(), result));
1418 }
1419
// Clones this node (with `cloned_output` as the clone's output) and enqueues
// its inputs for later cloning; used by `Snapshot` to copy the tree
// breadth-first without recursive locking.
std::shared_ptr<Node> Node::SnapshotHelper(
    std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
  tf_shared_lock l(mu_);

  // Clone current node(`this`), also set clone of its output node
  // (`cloned_output`) to be the output node of the cloned node
  // (`cloned_current`).
  std::shared_ptr<Node> cloned_current = Clone(cloned_output);
  {
    // The atomic members are copied with plain stores; only `parameters_`
    // (guarded by the clone's `mu_`) requires taking the clone's mutex, which
    // happens below. `this->mu_` is already held in shared mode.
    cloned_current->autotune_.store(autotune_);
    cloned_current->buffered_bytes_.store(buffered_bytes_);
    cloned_current->buffered_elements_.store(buffered_elements_);
    cloned_current->bytes_consumed_.store(bytes_consumed_);
    cloned_current->bytes_produced_.store(bytes_produced_);
    cloned_current->num_elements_.store(num_elements_);
    // Snapshots are analysis-only copies; they must not report metrics.
    cloned_current->record_metrics_.store(false);
    cloned_current->processing_time_.store(processing_time_);
    mutex_lock l2(cloned_current->mu_);
    cloned_current->parameters_ = parameters_;
  }

  // Defer cloning of the inputs: each will later be cloned with
  // `cloned_current` as its output.
  for (auto& input : inputs_) {
    node_pairs->push_back(std::make_pair(input, cloned_current));
  }
  return cloned_current;
}
1446
TotalBufferedBytesHelper(absl::flat_hash_map<string,double> * total_bytes) const1447 void Node::TotalBufferedBytesHelper(
1448 absl::flat_hash_map<string, double>* total_bytes) const
1449 TF_SHARED_LOCKS_REQUIRED(mu_) {
1450 if (!autotune_) {
1451 total_bytes->insert(std::make_pair(long_name(), 0));
1452 return;
1453 }
1454
1455 double result = 0;
1456 auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1457 if (!parameter) {
1458 parameter = gtl::FindOrNull(parameters_, kParallelism);
1459 }
1460 if (parameter) {
1461 result = buffered_bytes_;
1462 }
1463 for (auto& input : inputs_) {
1464 result += total_bytes->at(input->long_name());
1465 }
1466 total_bytes->insert(std::make_pair(long_name(), result));
1467 }
1468
TotalMaximumBufferedBytesHelper(absl::flat_hash_map<string,double> * total_bytes) const1469 void Node::TotalMaximumBufferedBytesHelper(
1470 absl::flat_hash_map<string, double>* total_bytes) const
1471 TF_SHARED_LOCKS_REQUIRED(mu_) {
1472 if (!autotune_) {
1473 total_bytes->insert(std::make_pair(long_name(), 0));
1474 return;
1475 }
1476
1477 double result = MaximumBufferedBytes();
1478 for (auto& input : inputs_) {
1479 result += total_bytes->at(input->long_name());
1480 }
1481 total_bytes->insert(std::make_pair(long_name(), result));
1482 }
1483
// Base implementation: a generic node buffers nothing; subclasses with
// buffers (e.g. AsyncKnownRatio) override this.
double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
  return 0;
}
1487
ToProto(ModelProto::Node * node_proto) const1488 Status Node::ToProto(ModelProto::Node* node_proto) const {
1489 tf_shared_lock l(mu_);
1490 node_proto->set_id(id_);
1491 node_proto->set_name(name_);
1492 node_proto->set_autotune(autotune_);
1493 node_proto->set_buffered_bytes(buffered_bytes_);
1494 node_proto->set_buffered_elements(buffered_elements_);
1495 node_proto->set_bytes_consumed(bytes_consumed_);
1496 node_proto->set_bytes_produced(bytes_produced_);
1497 node_proto->set_num_elements(num_elements_);
1498 node_proto->set_processing_time(processing_time_);
1499 node_proto->set_record_metrics(record_metrics_);
1500
1501 // Produce protos for all parameters.
1502 for (auto const& parameter : parameters_) {
1503 ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
1504 parameter_proto->set_name(parameter.first);
1505 parameter_proto->set_value(parameter.second->value);
1506 parameter_proto->set_min(parameter.second->min);
1507 parameter_proto->set_max(parameter.second->max);
1508 parameter_proto->set_state_value(parameter.second->state->value);
1509 parameter_proto->set_tunable(parameter.second->state->tunable);
1510 }
1511
1512 // Produce protos for all inputs.
1513 for (auto const& input : inputs_) {
1514 ModelProto::Node* input_proto = node_proto->add_inputs();
1515 TF_RETURN_IF_ERROR(input->ToProto(input_proto));
1516 }
1517 return Status::OK();
1518 }
1519
// Restores the mutable state of `node` (counters and parameters) from
// `node_proto`. Structural information (node class, ratio, inputs) is handled
// by `FromProto`.
Status Node::FromProtoHelper(ModelProto::Node node_proto,
                             std::shared_ptr<Node> node) {
  // NOTE(review): a shared (reader) lock is taken even though the fields
  // below are written; this appears safe only because `node` was just created
  // by `FromProto` and is not yet visible to other threads -- TODO confirm.
  tf_shared_lock l(node->mu_);
  node->autotune_.store(node_proto.autotune());
  node->buffered_bytes_.store(node_proto.buffered_bytes());
  node->buffered_elements_.store(node_proto.buffered_elements());
  node->bytes_consumed_.store(node_proto.bytes_consumed());
  node->bytes_produced_.store(node_proto.bytes_produced());
  node->num_elements_.store(node_proto.num_elements());
  node->processing_time_.store(node_proto.processing_time());
  node->record_metrics_.store(node_proto.record_metrics());

  // Restore parameters.
  int64 num_parameters = node_proto.parameters_size();
  for (int i = 0; i < num_parameters; i++) {
    const ModelProto::Node::Parameter& parameter_proto =
        node_proto.parameters(i);
    std::shared_ptr<SharedState> state;
    if (parameter_proto.tunable()) {
      // Tunable parameters are recreated in autotune mode; the saved state
      // value is restored separately right after construction.
      state =
          std::make_shared<SharedState>(kAutotune, std::make_shared<mutex>(),
                                        std::make_shared<condition_variable>());
      state->value = parameter_proto.state_value();
    } else {
      // Non-tunable parameters carry their saved state value directly.
      state = std::make_shared<SharedState>(
          parameter_proto.state_value(), std::make_shared<mutex>(),
          std::make_shared<condition_variable>());
    }
    node->parameters_[parameter_proto.name()] =
        MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
                      parameter_proto.max());
  }
  return Status::OK();
}
1554
// Recursively reconstructs the node (and its inputs) of the concrete class
// recorded in `node_proto`, attaching it to `output`. Counters and parameters
// are restored by `FromProtoHelper`.
//
// NOTE(review): `node_proto` is passed by value, so every recursion level
// copies its entire subtree proto; changing this to a const reference would
// require a matching header change.
Status Node::FromProto(ModelProto::Node node_proto,
                       std::shared_ptr<Node> output,
                       std::shared_ptr<Node>* node) {
  // Note that parameters are restored in `FromProtoHelper`.
  Args args = {node_proto.id(), node_proto.name(), std::move(output)};
  std::shared_ptr<Node> restored_node;
  // Instantiate the concrete node class; async classes are created with empty
  // parameter lists since their parameters are restored from the proto below.
  switch (node_proto.node_class()) {
    case NodeClass::INTERLEAVE_MANY:
      restored_node = std::make_shared<InterleaveMany>(args);
      break;
    case NodeClass::ASYNC_INTERLEAVE_MANY:
      restored_node = std::make_shared<AsyncInterleaveMany>(
          args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
      break;
    case NodeClass::KNOWN_RATIO:
      restored_node = std::make_shared<KnownRatio>(args, node_proto.ratio());
      break;
    case NodeClass::ASYNC_KNOWN_RATIO:
      restored_node = std::make_shared<AsyncKnownRatio>(
          args, node_proto.ratio(), node_proto.memory_ratio(),
          /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
      break;
    case NodeClass::UNKNOWN_RATIO:
      restored_node = std::make_shared<UnknownRatio>(args);
      break;
    default:
      // Unrecognized classes fall back to `Unknown`.
      restored_node = std::make_shared<Unknown>(args);
  }
  TF_RETURN_IF_ERROR(FromProtoHelper(node_proto, restored_node));

  // Restore all input nodes as well.
  int64 num_inputs = node_proto.inputs_size();
  for (int64 i = 0; i < num_inputs; i++) {
    const ModelProto::Node& input_proto = node_proto.inputs(i);
    std::shared_ptr<Node> input;
    TF_RETURN_IF_ERROR(FromProto(input_proto, restored_node, &input));
    restored_node->add_input(input);
  }
  (*node) = std::move(restored_node);
  return Status::OK();
}
1596
AddNode(Node::Factory factory,const string & name,std::shared_ptr<Node> parent,std::shared_ptr<Node> * out_node)1597 void Model::AddNode(Node::Factory factory, const string& name,
1598 std::shared_ptr<Node> parent,
1599 std::shared_ptr<Node>* out_node) {
1600 // The name captures the sequence of iterators joined by `::`. We only use the
1601 // last element of the sequence as the name node.
1602 auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
1603 mutex_lock l(mu_);
1604 std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
1605 if (!output_) {
1606 output_ = node;
1607 }
1608 if (parent) {
1609 VLOG(3) << "Adding " << node->long_name() << " as input for "
1610 << parent->long_name();
1611 parent->add_input(node);
1612 } else {
1613 VLOG(3) << "Adding " << node->long_name();
1614 }
1615 collect_resource_usage_ =
1616 collect_resource_usage_ || node->has_tunable_parameters();
1617 *out_node = std::move(node);
1618 // TODO(jsimsa): Reset the optimization period when a node is added so that
1619 // autotuning adapts to changes to the input pipeline faster. Initial attempt
1620 // to enable this functionality caused a regression (see b/179812091).
1621 }
1622
FlushMetrics()1623 void Model::FlushMetrics() {
1624 std::deque<std::shared_ptr<Node>> queue;
1625 {
1626 tf_shared_lock l(mu_);
1627 if (output_) queue.push_back(output_);
1628 }
1629 while (!queue.empty()) {
1630 auto node = queue.front();
1631 queue.pop_front();
1632 node->FlushMetrics();
1633 for (auto input : node->inputs()) {
1634 queue.push_back(input);
1635 }
1636 }
1637 }
1638
// Runs one round of autotuning with the given algorithm and budgets, and, if
// saving is enabled, enqueues the optimized snapshot for the save loop.
void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
                     int64 ram_budget, double model_input_time) {
  std::shared_ptr<Node> snapshot;
  {
    // Optimization operates on a snapshot so the live graph is not locked for
    // the duration of the (potentially long) optimization.
    tf_shared_lock lock(mu_);
    snapshot = output_->Snapshot();
  }
  OptimizationParams optimization_params;
  optimization_params.set_algorithm(algorithm);
  optimization_params.set_cpu_budget(cpu_budget);
  optimization_params.set_ram_budget(ram_budget);
  optimization_params.set_model_input_time(model_input_time);
  switch (algorithm) {
    case AutotuneAlgorithm::HILL_CLIMB:
      OptimizeHillClimb(snapshot, optimization_params);
      break;
    case AutotuneAlgorithm::GRADIENT_DESCENT:
      OptimizeGradientDescent(snapshot, optimization_params);
      break;
    default:
      VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
                 "optimization.";
      return;
  }
  if (!save_dir_.empty()) {
    mutex_lock lock(mu_);
    Status status = EnsureSaveLoopThreadStarted();
    // Hand the snapshot off to the save thread unless its bounded buffer is
    // full; in that case the snapshot is silently dropped (logged below).
    if (status.ok() && save_buffer_.size() < kMaxNumBufferedOptimizeArgs) {
      save_buffer_.push_back(std::make_pair(snapshot, optimization_params));
      save_cond_var_.notify_all();
    } else if (save_buffer_.size() >= kMaxNumBufferedOptimizeArgs) {
      VLOG(3) << "Saved snapshots buffer is full. Current snapshot and "
                 "optimization parameters will not be saved.";
    }
  }
}
1675
RemoveNode(std::shared_ptr<Node> node)1676 void Model::RemoveNode(std::shared_ptr<Node> node) {
1677 mutex_lock l(mu_);
1678 if (node) {
1679 if (node->output()) {
1680 node->output()->remove_input(node);
1681 }
1682 VLOG(3) << "Removing " << node->long_name();
1683 }
1684 }
1685
1686 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
CollectTunableParameters(std::shared_ptr<Node> node)1687 Model::CollectTunableParameters(std::shared_ptr<Node> node) {
1688 absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters;
1689 node->CollectTunableParameters(¶meters);
1690 return parameters;
1691 }
1692
// Returns true when the gradient-descent iteration should terminate: either
// every parameter under consideration has reached its maximum, or the
// estimated maximum buffered bytes exceed `ram_budget`. Once total
// parallelism exceeds `cpu_budget`, only buffer size parameters are
// considered in subsequent calls (sticky via `*cpu_budget_reached`).
bool Model::ShouldStop(
    int64 cpu_budget, int64 ram_budget,
    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
        parallelism_parameters,
    const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
        buffer_size_parameters,
    std::shared_ptr<Node> snapshot, bool* cpu_budget_reached) {
  if (!(*cpu_budget_reached)) {
    // If those essential transformations' parallelism reaches the CPU
    // budget, we will only tune the buffer size parameters in future
    // iterations.
    int64 model_parallelism = 0;
    for (auto& pair : parallelism_parameters) {
      model_parallelism += std::round(pair.second->value);
    }
    *cpu_budget_reached = (model_parallelism > cpu_budget);
  }

  // Check whether every parameter still being tuned is already at its max.
  bool all_max = true;
  for (auto& pair :
       (*cpu_budget_reached ? buffer_size_parameters : parameters)) {
    if (std::round(pair.second->value) < pair.second->max) {
      all_max = false;
      break;
    }
  }

  // If all parameters have reached their maximum values or RAM budget is
  // reached, we stop the iterations.
  return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
}
1725
1726 // TODO(jsimsa): Add support for tracking and using the model input time.
// Runs `Optimize` periodically until `cancellation_manager` is cancelled,
// doubling the optimization period after each round up to
// `kOptimizationPeriodMaxMs`.
Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
                           int64 ram_budget,
                           CancellationManager* cancellation_manager) {
  std::function<void()> unused;
  // Wake the waiting loop below on cancellation so it exits promptly instead
  // of sleeping out the remainder of the period.
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      cancellation_manager,
      [this]() {
        mutex_lock l(mu_);
        optimize_cond_var_.notify_all();
      },
      /*deregister_fn=*/&unused));

  int64 last_optimization_ms = 0;
  int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
  while (true) {
    {
      mutex_lock l(mu_);
      // Wait until the next optimization is due or cancellation fires; the
      // wait can return spuriously or early, hence the re-check loop.
      while (!cancellation_manager->IsCancelled() &&
             last_optimization_ms + optimization_period_ms_ > current_time_ms) {
        auto wait_ms =
            last_optimization_ms + optimization_period_ms_ - current_time_ms;
        VLOG(2) << "Waiting for " << wait_ms << " ms.";
        optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
        current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
      }
      if (cancellation_manager->IsCancelled()) {
        return Status::OK();
      }
    }

    int64 optimization_start_us = EnvTime::NowMicros();
    Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0);
    VLOG(2) << "Optimized for "
            << (EnvTime::NowMicros() - optimization_start_us) << " us.";

    // Exponentially increase the period of running the optimization
    // until a threshold is reached.
    {
      mutex_lock l(mu_);
      optimization_period_ms_ =
          std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
    }
    current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
    last_optimization_ms = current_time_ms;
    FlushMetrics();
  }
}
1774
// Tunes all tunable parameters by gradient descent on the estimated output
// latency of `snapshot`, then rounds the resulting values and commits them to
// the live parameter state via `UpdateStateValues`.
void Model::OptimizeGradientDescent(
    std::shared_ptr<Node> snapshot,
    const OptimizationParams& optimization_params) {
  VLOG(2) << "Starting optimization of tunable parameters with Gradient "
             "Descent.";
  auto parameters = CollectTunableParameters(snapshot);
  if (parameters.empty()) {
    VLOG(2) << "The Gradient Descent optimization is terminated since no node "
               "with tunable parameters has recorded elements.";
    return;
  }

  // The maps of "essential" parallelism parameters and buffer size parameters.
  absl::flat_hash_map<string, std::shared_ptr<Parameter>>
      parallelism_parameters, buffer_size_parameters;
  CollectParameters(snapshot, parameters, &parallelism_parameters,
                    &buffer_size_parameters);

  // Initialize the parameter values to minimal before tuning.
  for (auto& pair : parameters) {
    pair.second->value = pair.second->min;
  }

  // Optimization is stopped once the `OutputTime` improvement is smaller than
  // this value.
  constexpr double kOptimizationPrecision = 100.0L;

  // Maximum number of iterations for optimization.
  constexpr int64 kMaxIterations = 1000;

  double output_time = 0;
  double new_output_time;

  // When the CPU budget is reached, the parallelism parameter values are fixed
  // and we only increase the buffer size parameters.
  bool cpu_budget_reached = false;

  for (int i = 0; i < kMaxIterations &&
                  !ShouldStop(optimization_params.cpu_budget(),
                              optimization_params.ram_budget(), parameters,
                              parallelism_parameters, buffer_size_parameters,
                              snapshot, &cpu_budget_reached);
       ++i) {
    absl::flat_hash_map<string, double> gradients;
    new_output_time = OutputTime(
        snapshot, optimization_params.model_input_time(), &gradients);
    // We also terminate once the improvement of the output latency is too
    // small.
    if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
      break;
    }

    // Take one descent step on the parameters still being tuned.
    UpdateParameterValues(
        gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
    output_time = new_output_time;
  }

  // Parameters take integer values at runtime; round the continuous optimum.
  for (auto& pair : parameters) {
    pair.second->value = std::round(pair.second->value);
  }
  UpdateStateValues(&parameters);
}
1837
// Greedy hill-climbing autotuning: starting from minimal parameter values,
// repeatedly increments the single parameter whose unit increase yields the
// largest output-latency improvement, until the latency target, the budgets,
// or a local optimum is reached. Commits results via `UpdateStateValues`.
void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                              const OptimizationParams& optimization_params) {
  VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
  const double processing_time = TotalProcessingTime(snapshot);
  auto parameters = CollectTunableParameters(snapshot);
  if (parameters.empty()) {
    VLOG(2) << "The Hill Climb optimization is terminated since no node with "
               "tunable parameters has recorded elements.";
    return;
  }

  // Buffer size parameter will only be incremented if the output latency
  // improvement is greater than this constant.
  constexpr double kBufferSizeMinDelta = 1.0L;

  // Initialize the parameter values to minimal before tuning.
  for (auto& pair : parameters) {
    pair.second->value = pair.second->min;
  }
  while (true) {
    const double output_time =
        OutputTime(snapshot, optimization_params.model_input_time(),
                   /*gradients=*/nullptr);
    bool all_max = true;
    for (auto& pair : parameters) {
      if (pair.second->value < pair.second->max) {
        all_max = false;
        break;
      }
    }
    // Stop when the latency target is met, nothing can be increased further,
    // or the RAM budget would be exceeded.
    if (output_time < processing_time / optimization_params.cpu_budget() ||
        all_max ||
        TotalMaximumBufferedBytes(snapshot) >
            optimization_params.ram_budget()) {
      break;
    }
    // Probe each parameter with a tentative unit increment (undone below) and
    // remember the one producing the best latency improvement.
    double best_delta = -1.0L;
    Parameter* best_parameter = nullptr;
    for (auto& pair : parameters) {
      if (pair.second->value >= pair.second->max) {
        continue;
      }
      pair.second->value++;
      double new_output_time =
          OutputTime(snapshot, optimization_params.model_input_time(),
                     /*gradients=*/nullptr);
      double delta = output_time - new_output_time;
      if (delta > best_delta &&
          (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
        best_delta = delta;
        best_parameter = pair.second.get();
      }
      pair.second->value--;
    }
    if (!best_parameter) {
      VLOG(2) << "Failed to find a tunable parameter that would further "
                 "decrease the output time. This means that the autotuning "
                 "optimization got stuck in a local maximum. The optimization "
                 "attempt will terminate early.";
      break;
    }
    best_parameter->value++;
  }
  UpdateStateValues(&parameters);
}
1903
OutputTime(std::shared_ptr<Node> node,double model_input_time,absl::flat_hash_map<string,double> * gradients)1904 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
1905 absl::flat_hash_map<string, double>* gradients) {
1906 // To store the input time for each node.
1907 absl::flat_hash_map<string, double> input_times = {
1908 {kModelInputTimeKey, model_input_time}};
1909
1910 // TODO(jsimsa): Now that we are accounting for buffer size in wait time
1911 // computation, assuming that the input is infinitely fast will result in
1912 // inaccurate estimates of the output latency.
1913 //
1914 // We should compute the output latency as a fix-point of the following
1915 // equation: `output_time = node(OutputTime(input_times(1, output_time))`.
1916
1917 return node->OutputTime(&input_times, gradients);
1918 }
1919
// Returns the number of bytes currently buffered in the tree rooted at
// `node`; thin wrapper over `Node::TotalBufferedBytes`.
double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
  return node->TotalBufferedBytes();
}
1923
// Returns the maximum number of bytes the tree rooted at `node` could buffer;
// thin wrapper over `Node::TotalMaximumBufferedBytes`.
double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
  return node->TotalMaximumBufferedBytes();
}
1927
// Returns the total processing time of the tree rooted at `node`; the
// per-node breakdown is not requested (null `processing_times`).
double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
  return node->TotalProcessingTime(/*processing_times=*/nullptr);
}
1931
ToProto(ModelProto * model_proto)1932 Status Model::ToProto(ModelProto* model_proto) {
1933 ModelProto::Node* output_proto = model_proto->mutable_output();
1934 tf_shared_lock lock(mu_);
1935 TF_RETURN_IF_ERROR(output_->ToProto(output_proto));
1936 model_proto->set_id_counter(id_counter_);
1937 model_proto->set_collect_resource_usage(collect_resource_usage_);
1938 return Status::OK();
1939 }
1940
FromProto(ModelProto model_proto,std::unique_ptr<Model> * model)1941 Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
1942 std::unique_ptr<Model> restored_model = std::make_unique<Model>();
1943 std::shared_ptr<Node> output;
1944 TF_RETURN_IF_ERROR(
1945 Node::FromProto(model_proto.output(), /*output=*/nullptr, &output));
1946 mutex_lock lock(restored_model->mu_);
1947 restored_model->output_ = output;
1948 restored_model->id_counter_ = model_proto.id_counter();
1949 restored_model->collect_resource_usage_.store(
1950 model_proto.collect_resource_usage());
1951 *model = std::move(restored_model);
1952 return Status::OK();
1953 }
1954
// Serializes `snapshot` together with `optimization_params` into a
// `ModelProto` and writes it to `fname`.
Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
                   const OptimizationParams& optimization_params) {
  ModelProto model_proto;
  // Wrap the snapshot in a throwaway model so `Model::ToProto` can be reused.
  std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
  {
    mutex_lock lock(model_snapshot->mu_);
    model_snapshot->output_ = std::move(snapshot);
    // NOTE(review): `id_counter_` is read here without holding this model's
    // own `mu_` (only the throwaway model's lock is held); presumably a
    // slightly stale value is acceptable for a saved snapshot -- confirm.
    model_snapshot->id_counter_ = id_counter_;
    model_snapshot->collect_resource_usage_.store(collect_resource_usage_);
  }
  TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
  OptimizationParams* saved_optimization_params =
      model_proto.mutable_optimization_params();
  *saved_optimization_params = optimization_params;
  return WriteBinaryProto(Env::Default(), fname, model_proto);
}
1971
Load(const string & fname,std::unique_ptr<Model> * model,OptimizationParams * optimization_params)1972 Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
1973 OptimizationParams* optimization_params) {
1974 ModelProto model_proto;
1975 TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), fname, &model_proto));
1976 TF_RETURN_IF_ERROR(FromProto(model_proto, model));
1977 const OptimizationParams restored_optimization_params =
1978 model_proto.optimization_params();
1979 *optimization_params = restored_optimization_params;
1980 return Status::OK();
1981 }
1982
// Lazily starts the background thread that drains `save_buffer_`. The caller
// must hold `mu_` exclusively, which makes the check-then-create of
// `save_thread_` race-free.
Status Model::EnsureSaveLoopThreadStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!save_thread_) {
    save_thread_ = absl::WrapUnique(
        Env::Default()->StartThread({}, "tf_data_model_save", [this]() {
          Status status = SaveLoop();
          if (!status.ok()) {
            // Saving is best-effort; failures are logged, not propagated.
            VLOG(2) << "Model save loop failed: " << status.ToString();
          }
        }));
  }
  return Status::OK();
}
1995
SaveLoop()1996 Status Model::SaveLoop() {
1997 TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(save_dir_));
1998 while (true) {
1999 std::pair<std::shared_ptr<Node>, OptimizationParams> to_save;
2000 {
2001 mutex_lock l(mu_);
2002 while (!save_thread_cancelled_ && save_buffer_.empty()) {
2003 save_cond_var_.wait(l);
2004 }
2005 if (save_thread_cancelled_) {
2006 return Status::OK();
2007 }
2008 to_save = save_buffer_.front();
2009 save_buffer_.pop_front();
2010 }
2011 string model_name =
2012 absl::StrCat("autotune_model_",
2013 Hash64Combine(static_cast<uint64>(EnvTime::NowMicros()),
2014 reinterpret_cast<uint64>(this)));
2015 string fname = io::JoinPath(save_dir_, model_name);
2016 TF_RETURN_IF_ERROR(Save(fname, to_save.first, to_save.second));
2017 VLOG(2) << "Model was saved as " << fname;
2018 }
2019 }
2020
2021 } // namespace model
2022 } // namespace data
2023 } // namespace tensorflow
2024