/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/model.h"

#include <memory>

namespace tensorflow {
namespace data {
namespace model {

std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         int64 min, int64 max) {
  return std::make_shared<Parameter>(name, state, min, max);
}

namespace {

// Given the average time between output events (`output_time`), the average
// time between input events (`input_time`) and the buffer size, the method
// computes the expected time an input event will have to wait.
//
// The wait time is approximated as the product of the probability the buffer
// will be empty and the time it takes to produce an element into the buffer.
//
// The formula used for computing the probability is derived by modeling the
// problem as an M/M/1/K queue
// (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue).
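//
// Concretely, with arrival rate `alpha = 1 / input_time` and service rate
// `beta = 1 / output_time`, the code below evaluates
//
//   p_buffer_empty = (1 - beta / alpha) /
//                    (1 - (beta / alpha)^(buffer_size + 1))
//
// and returns `p_buffer_empty * output_time`. For example, for
// `input_time = 10`, `output_time = 20` and `buffer_size = 1`, we get
// `beta / alpha = 0.5`, `p_buffer_empty = 0.5 / 0.75 = 2/3`, and an expected
// wait time of ~13.3 time units.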
int64 ComputeWaitTime(int64 output_time, int64 input_time, int64 buffer_size) {
  if (output_time == 0 || input_time == 0) {
    return output_time;
  }
  if (input_time == output_time) {
    const double p_buffer_empty = 1.0L / static_cast<double>(buffer_size + 1);
    return p_buffer_empty * output_time;
  }
  const double alpha = 1.0L / static_cast<double>(input_time);
  const double beta = 1.0L / static_cast<double>(output_time);
  const double p_buffer_empty =
      (1.0L - beta / alpha) /
      (1.0L - std::pow((beta / alpha), static_cast<double>(buffer_size + 1)));
  return p_buffer_empty * output_time;
}

// The first input of InterleaveMany corresponds to the input dataset whose
// elements are used to create the (derived) input datasets whose elements are
// interleaved as output.
//
// TODO(jsimsa): model the first input
class InterleaveMany : public Node {
 public:
  using Node::Node;

  virtual ~InterleaveMany() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<InterleaveMany>(
        Args{id_, name_, std::move(output)});
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    if (inputs_.size() <= 1) {
      return NanosPerElementLocked();
    }
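    // A note on the estimate below: the interleaved inputs are read in a
    // round-robin fashion, so each of them sees a request roughly once per
    // cycle. The delta temporarily adds one cycle's worth of this node's
    // self-time to the input time observed by the inputs while their output
    // times are computed; the cleanup below undoes the adjustment.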
    int64 delta = NanosPerElementLocked() * (inputs_.size() - 1);
    input_times->back() += delta;
    auto cleanup = gtl::MakeCleanup(
        [input_times, delta]() { input_times->back() -= delta; });
    int64 output_time =
        static_cast<double>(OutputTimeForInputs(input_times) -
                            inputs_.front()->OutputTime(input_times)) /
        static_cast<double>(inputs_.size() - 1);
    return NanosPerElementLocked() + output_time;
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    if (inputs_.size() <= 1) {
      return NanosPerElementLocked();
    }
    int64 processing_time =
        static_cast<double>(ProcessingTimeForInputs() -
                            inputs_.front()->ProcessingTime()) /
        static_cast<double>(inputs_.size() - 1);
    return NanosPerElementLocked() + processing_time;
  }
};

// TODO(jsimsa): model the first input
class AsyncInterleaveMany : public Node {
 public:
  AsyncInterleaveMany(Node::Args args,
                      std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncInterleaveMany() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncInterleaveMany>(
        Args{id_, name_, std::move(output)}, parameters);
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    if (inputs_.size() <= 1) {
      return NanosPerElementLocked();
    }
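    // Each interleaved input sees a request roughly once per cycle, so the
    // input time pushed below is this node's self-time multiplied by the
    // cycle length. The per-input output time is additionally divided by the
    // parallelism, since up to `parallelism` inputs produce elements
    // concurrently.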
    int64 old_input_time = input_times->back();
    int64 new_input_time = static_cast<double>(NanosPerElementLocked()) *
                           static_cast<double>(inputs_.size() - 1);
    input_times->push_back(new_input_time);
    auto cleanup =
        gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
    double parallelism = inputs_.size() - 1;  // default to cycle length
    if (auto* parameter = gtl::FindOrNull(parameters_, "parallelism")) {
      parallelism = std::min(static_cast<int>(parallelism),
                             static_cast<int>((*parameter)->value));
    }
    int64 output_time =
        static_cast<double>(OutputTimeForInputs(input_times) -
                            inputs_.front()->OutputTime(input_times)) /
        static_cast<double>(inputs_.size() - 1) / parallelism;
    return ComputeWaitTime(NanosPerElementLocked() + output_time,
                           old_input_time, parallelism);
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    if (inputs_.size() <= 1) {
      return NanosPerElementLocked();
    }
    int64 processing_time =
        ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
    return NanosPerElementLocked() +
           static_cast<double>(processing_time) /
               static_cast<double>(inputs_.size() - 1);
  }
};

class KnownRatio : public Node {
 public:
  KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}

  virtual ~KnownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
                                        ratio_);
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    if (ratio_ == 0) {
      return NanosPerElementLocked();
    }
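    // Each output element consumes `ratio_` input elements, so the input
    // time observed by the inputs is adjusted by this node's self-time
    // amortized over the ratio; the cleanup below restores the original
    // value once the input output times have been computed.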
    int64 old_input_time = input_times->back();
    input_times->back() += static_cast<int64>(
        static_cast<double>(old_input_time + NanosPerElementLocked()) / ratio_);
    auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() {
      input_times->back() = old_input_time;
    });
    return NanosPerElementLocked() + ratio_ * OutputTimeForInputs(input_times);
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    return NanosPerElementLocked() + ratio_ * ProcessingTimeForInputs();
  }

 private:
  const double ratio_;
};

class AsyncKnownRatio : public Node {
 public:
  AsyncKnownRatio(Node::Args args, double ratio,
                  std::vector<std::shared_ptr<Parameter>> parameters)
      : Node(args), ratio_(ratio) {
    for (auto& parameter : parameters) {
      parameters_[parameter->name] = std::move(parameter);
    }
  }

  virtual ~AsyncKnownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    std::vector<std::shared_ptr<Parameter>> parameters;
    for (auto& pair : parameters_) {
      parameters.push_back(pair.second);
    }
    return std::make_shared<AsyncKnownRatio>(
        Args{id_, name_, std::move(output)}, ratio_, parameters);
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    double parallelism = 1.0;
    if (auto* parameter = gtl::FindOrNull(parameters_, "parallelism")) {
      parallelism = (*parameter)->value;
    }
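    // A ratio of zero means this node produces elements without consuming
    // input elements; in that case the expected output time is the
    // self-time amortized across the parallel workers, weighted by the
    // probability that the consumer finds the buffer empty.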
    if (ratio_ == 0.0) {
      int64 output_time =
          static_cast<double>(NanosPerElementLocked()) / parallelism;
      return ComputeWaitTime(output_time, input_times->back(), parallelism);
    }
    int64 old_input_time = input_times->back();
    int64 new_input_time = static_cast<int64>(
        static_cast<double>(NanosPerElementLocked()) / ratio_ / parallelism);
    input_times->push_back(new_input_time);
    auto cleanup =
        gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
    int64 output_time = static_cast<int64>(
        static_cast<double>(NanosPerElementLocked()) / parallelism +
        ratio_ * OutputTimeForInputs(input_times));
    return ComputeWaitTime(output_time, old_input_time, parallelism);
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    return NanosPerElementLocked() + ratio_ * ProcessingTimeForInputs();
  }

 private:
  const double ratio_;
};

class UnknownRatio : public Node {
 public:
  using Node::Node;

  virtual ~UnknownRatio() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    if (num_elements_ == 0 || inputs_.empty() ||
        inputs_.front()->num_elements() == 0) {
      return NanosPerElementLocked();
    }
    // TODO(jsimsa): The current implementation assumes that the number of
    // input elements consumed per output is the same across all inputs.
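    // The ratio is estimated empirically, as the number of elements the
    // first input has produced for each element this node has produced so
    // far.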
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    int64 old_input_time = input_times->back();
    input_times->back() =
        static_cast<double>(old_input_time + NanosPerElementLocked()) / ratio;
    auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() {
      input_times->back() = old_input_time;
    });
    return NanosPerElementLocked() +
           static_cast<int64>(
               ratio * static_cast<double>(OutputTimeForInputs(input_times)));
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    if (inputs_.empty() || num_elements_ == 0) {
      return NanosPerElementLocked();
    }
    // TODO(jsimsa): The current implementation assumes that the number of
    // input elements consumed per output is the same across all inputs.
    std::shared_ptr<Node> input = inputs_.front();
    double ratio = static_cast<double>(input->num_elements()) /
                   static_cast<double>(num_elements_);
    return NanosPerElementLocked() +
           static_cast<int64>(ratio *
                              static_cast<double>(ProcessingTimeForInputs()));
  }
};

class Unknown : public Node {
 public:
  using Node::Node;

  virtual ~Unknown() {}

 protected:
  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
  }

  int64 OutputTimeLocked(std::vector<int64>* input_times) const override
      SHARED_LOCKS_REQUIRED(mu_) {
    return OutputTimeForInputs(input_times);
  }

  int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) {
    return ProcessingTimeForInputs();
  }
};

}  // namespace

std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args) {
  return std::make_shared<InterleaveMany>(std::move(args));
}

std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncInterleaveMany>(std::move(args),
                                               std::move(parameters));
}

std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
  return std::make_shared<KnownRatio>(std::move(args), ratio);
}

std::shared_ptr<Node> MakeAsyncKnownRatioNode(
    Node::Args args, double ratio,
    std::vector<std::shared_ptr<Parameter>> parameters) {
  return std::make_shared<AsyncKnownRatio>(std::move(args), ratio,
                                           std::move(parameters));
}

std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
  return MakeKnownRatioNode(std::move(args), 0);
}

std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
  return std::make_shared<UnknownRatio>(std::move(args));
}

std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
  return std::make_shared<Unknown>(std::move(args));
}
std::shared_ptr<Node> Model::AddNode(Node::Factory factory, const string& name,
                                     const string& output_name) {
  // The name captures the sequence of iterators joined by `::`. We use the
  // full sequence as the key in the lookup table, but only the last element
  // of the sequence as the name of the node.
  std::vector<string> tokens =
      str_util::Split(name, ':', str_util::SkipEmpty());
  // The output name might contain an index. We need to strip it to make it
  // possible for the model to successfully identify the output node.
  string sanitized_output_name = output_name;
  if (str_util::EndsWith(output_name, "]")) {
    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
  }
  std::shared_ptr<Node> output;
  mutex_lock l(mu_);
  auto it = lookup_table_.find(sanitized_output_name);
  if (it != lookup_table_.end()) {
    output = it->second;
  }
  std::shared_ptr<Node> node = factory({id_counter_++, tokens.back(), output});
  if (!output_) {
    output_ = node;
  }
  if (output) {
    VLOG(3) << "Adding " << node->name() << "(id:" << node->id()
            << ") as input for " << output->name() << "(id:" << output->id()
            << ")";
    output->add_input(node);
  } else {
    VLOG(3) << "Adding " << node->name() << "(id:" << node->id() << ")";
  }
  collect_resource_usage_ =
      collect_resource_usage_ || node->has_tunable_parameters();
  lookup_table_.insert(std::make_pair(name, node));
  return node;
}

void Model::AddProcessingTime(const string& name, int64 delta) {
  tf_shared_lock l(mu_);
  auto node = gtl::FindOrNull(lookup_table_, name);
  if (node) {
    (*node)->add_processing_time(delta);
  }
}

// The optimization algorithm starts by setting all tunable parallelism
// parameters to 1. It then repeatedly identifies the parameter whose increase
// in parallelism decreases the output time the most. This process is repeated
// until all parameters reach their maximum values or the projected output
// time drops below the processing time needed to produce an element divided
// by the CPU budget.
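//
// For example, with a CPU budget of 8 and an estimated processing time of
// 800us per element, the hill climbing stops once the projected output time
// drops below 100us per element (or once every parameter is at its maximum).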
void Model::Optimize(int64 cpu_budget) {
  std::shared_ptr<Node> snapshot;
  {
    tf_shared_lock lock(mu_);
    snapshot = output_->Snapshot(nullptr);
  }
  const int64 processing_time = ProcessingTime(snapshot);
  auto parameters = CollectTunableParameters(snapshot);
  for (auto& parameter : parameters) {
    parameter->value = 1;
  }
  while (true) {
    const int64 output_time = OutputTime(snapshot);
    bool all_max = true;
    for (auto& parameter : parameters) {
      if (parameter->value < parameter->max) {
        all_max = false;
        break;
      }
    }
    if (output_time < processing_time / cpu_budget || all_max) {
      break;
    }
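    // Greedily pick the parameter whose unit increase yields the largest
    // decrease in the estimated output time, by trying each candidate
    // increment on the snapshot and then reverting it.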
    int64 best_delta = -1;
    Parameter* best_parameter = nullptr;
    for (auto& parameter : parameters) {
      if (parameter->value == parameter->max) {
        continue;
      }
      parameter->value++;
      int64 delta = output_time - OutputTime(snapshot);
      if (delta > best_delta) {
        best_delta = delta;
        best_parameter = parameter.get();
      }
      parameter->value--;
    }
    if (!best_parameter) {
      // This should never happen because we are using a model snapshot and
      // the output time is monotonically decreasing w.r.t. parallelism.
      LOG(WARNING) << "Failed to find a tunable parameter that would "
                      "decrease the output time, aborting the current "
                      "optimization attempt.";
      return;
    }
    best_parameter->value++;
  }
  VLOG(2) << "Number of tunable parameters: " << parameters.size();
  for (auto& parameter : parameters) {
    VLOG(2) << "Setting tunable parameter: " << parameter->value;
    mutex_lock l(*parameter->state->mu);
    parameter->state->value = parameter->value;
    parameter->state->cond_var->notify_all();
  }
}

void Model::RecordElement(const string& name) {
  tf_shared_lock l(mu_);
  auto node = gtl::FindOrNull(lookup_table_, name);
  if (node) {
    (*node)->record_element();
  }
}

void Model::RecordStart(const string& name, bool stop_output) {
  tf_shared_lock l(mu_);
  auto node = gtl::FindOrNull(lookup_table_, name);
  if (collect_resource_usage_ && node) {
    int64 now_nanos = Env::Default()->NowNanos();
    if (stop_output && (*node)->output()) {
      (*node)->output()->record_stop(now_nanos);
    }
    (*node)->record_start(now_nanos);
  }
}

void Model::RecordStop(const string& name, bool start_output) {
  tf_shared_lock l(mu_);
  auto node = gtl::FindOrNull(lookup_table_, name);
  if (collect_resource_usage_ && node) {
    int64 now_nanos = Env::Default()->NowNanos();
    (*node)->record_stop(now_nanos);
    if (start_output && (*node)->output()) {
      (*node)->output()->record_start(now_nanos);
    }
  }
}

void Model::RemoveNode(const string& name) {
  mutex_lock l(mu_);
  auto node = gtl::FindOrNull(lookup_table_, name);
  if (node) {
    if ((*node)->output()) {
      (*node)->output()->remove_input(*node);
    }
    VLOG(3) << "Removing " << (*node)->name() << "(id:" << (*node)->id() << ")";
    remove_node_hook_(*node);
  }
  lookup_table_.erase(name);
}

std::vector<std::shared_ptr<Parameter>> Model::CollectTunableParameters(
    std::shared_ptr<Node> node) {
  std::vector<std::shared_ptr<Parameter>> parameters;
  node->CollectTunableParameters(&parameters);
  return parameters;
}

int64 Model::OutputTime(std::shared_ptr<Node> node) {
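  // The input time for the root of the pipeline is zero, i.e. elements are
  // assumed to be requested as fast as possible.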
  std::vector<int64> input_times(1, 0);
  return node->OutputTime(&input_times);
}

int64 Model::ProcessingTime(std::shared_ptr<Node> node) {
  return node->ProcessingTime();
}

}  // namespace model
}  // namespace data
}  // namespace tensorflow