1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ 16 #define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ 17 18 #include <list> 19 #include <memory> 20 #include <string> 21 // TODO(b/114492873): Move this include into core/platform. 22 #include <thread> // NOLINT 23 #include <utility> 24 #include <vector> 25 26 #include "absl/container/flat_hash_map.h" 27 #include "tensorflow/core/framework/cancellation.h" 28 #include "tensorflow/core/framework/metrics.h" 29 #include "tensorflow/core/framework/model.pb.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/lib/gtl/cleanup.h" 32 #include "tensorflow/core/lib/gtl/map_util.h" 33 #include "tensorflow/core/lib/histogram/histogram.h" 34 #include "tensorflow/core/lib/random/random.h" 35 #include "tensorflow/core/platform/cpu_info.h" 36 #include "tensorflow/core/platform/env.h" 37 #include "tensorflow/core/platform/mutex.h" 38 #include "tensorflow/core/platform/path.h" 39 40 namespace tensorflow { 41 namespace data { 42 namespace model { 43 44 // A constant that can be used to enable auto-tuning. 45 constexpr int64 kAutotune = -1; 46 constexpr char kParallelism[] = "parallelism"; 47 constexpr char kBufferSize[] = "buffer_size"; 48 49 // A key used to identify the input time of the model. 50 constexpr char kModelInputTimeKey[] = "model_input_time"; 51 52 enum class TraversalOrder { 53 BFS = 0, 54 REVERSE_BFS = 1, 55 }; 56 57 // Represents thread-safe state that can be shared between an input pipeline and 58 // the performance model. 59 struct SharedState { 60 public: SharedStateSharedState61 SharedState(int64 value, std::shared_ptr<mutex> mu, 62 std::shared_ptr<condition_variable> cond_var) 63 : value(value), 64 mu(std::move(mu)), 65 cond_var(std::move(cond_var)), 66 tunable(value == kAutotune) {} 67 68 double value; 69 const std::shared_ptr<mutex> mu; 70 const std::shared_ptr<condition_variable> cond_var; 71 const bool tunable; 72 }; 73 74 // Represents a parameter. 75 struct Parameter { ParameterParameter76 Parameter(const string& name, std::shared_ptr<SharedState> state, double min, 77 double max) 78 : name(name), 79 // Sometimes non-autotune nodes (with `autotune_=false`) may contain 80 // parameters (for example inputs of parallel interleave dataset which 81 // are not in the current cycle). To avoid unrealistic situation 82 // (say `buffer_size=-1` or `parallelism=-1`) in the optimization 83 // computation, if the state value is `kAutotune=-1` (just to indicate 84 // the `SharedState` is tunable), we initialize the parameter value to 85 // be the minimal value of the state. 86 value(state->value == kAutotune ? min : state->value), 87 min(min), 88 max(max), 89 state(std::move(state)) {} 90 91 // Human-readable name of the parameter. 92 const string name; 93 94 // Identifies the model value of the parameter. This can be different from 95 // the actual value (e.g. during optimization search). 96 double value; 97 98 // Identifies the minimum value of the parameter. 99 const double min; 100 101 // Identifies the maximum value of the parameter. 102 const double max; 103 104 // Shared state of the parameter. 105 std::shared_ptr<SharedState> state; 106 }; 107 108 std::shared_ptr<Parameter> MakeParameter(const string& name, 109 std::shared_ptr<SharedState> state, 110 double min, double max); 111 112 // Abstract representation of a TensorFlow input pipeline node. It collects 113 // information about inputs to this node, processing time spent executing the 114 // node logic, number of elements produced by the node, various other 115 // information (e.g. batch size or execution parallelism). 116 // 117 // Developers of tf.data transformations are not expected to interact with 118 // this class directly. Boiler plate code for creating the abstract 119 // representation of the input pipeline and collecting common information has 120 // been added to the implementation of `DatasetBase` and `DatasetBaseIterator` 121 // respectively. 122 // 123 // In addition, `DatasetBaseIterator` provides wrappers that can be used for 124 // transformation-specific information collection. The `SetMetadata` wrapper 125 // can be used to pass arbitrary metadata to the modeling framework, while the 126 // `StartWork` and `StopWork` wrappers should be used to correctly account for 127 // processing time of multi-threaded transformation that yield the CPU; such 128 // transformations should invoke `StartWork()` when a transformation thread 129 // starts executing (e.g. when created or woken up) and `StopWork()` when a 130 // transformation thread stops executing (e.g. when returning or waiting). 131 class Node { 132 public: 133 // Arguments for `Node` constructor. 134 struct Args { 135 int64 id; 136 string name; 137 std::shared_ptr<Node> output; 138 }; 139 140 using Factory = std::function<std::shared_ptr<Node>(Args)>; 141 using NodeVector = std::vector<std::shared_ptr<Node>>; 142 using NodePairList = 143 std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>; 144 Node(Args args)145 explicit Node(Args args) 146 : id_(args.id), 147 name_(std::move(args.name)), 148 autotune_(true), 149 buffered_bytes_(0), 150 buffered_elements_(0), 151 bytes_consumed_(0), 152 bytes_produced_(0), 153 num_elements_(0), 154 processing_time_(0), 155 record_metrics_(true), 156 metrics_(name_), 157 output_(args.output.get()) {} 158 ~Node()159 virtual ~Node() { 160 // Clear the sub-nodes instead of relying on implicit shared pointer 161 // destructor to avoid potential stack overflow when the tree is deep. 162 std::deque<std::shared_ptr<Node>> queue; 163 { 164 mutex_lock l(mu_); 165 while (inputs_.size() > 0) { 166 queue.push_back(inputs_.front()); 167 inputs_.pop_front(); 168 } 169 } 170 while (!queue.empty()) { 171 auto node = queue.back(); 172 queue.pop_back(); 173 { 174 mutex_lock l(node->mu_); 175 while (node->inputs_.size() > 0) { 176 queue.push_back(node->inputs_.front()); 177 node->inputs_.pop_front(); 178 } 179 } 180 } 181 182 FlushMetrics(); 183 } 184 185 // Adds an input. add_input(std::shared_ptr<Node> node)186 void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) { 187 mutex_lock l(mu_); 188 inputs_.push_back(node); 189 } 190 191 // Increments the aggregate processing time by the given delta. add_processing_time(int64 delta)192 void add_processing_time(int64 delta) TF_LOCKS_EXCLUDED(mu_) { 193 processing_time_ += delta; 194 } 195 196 // Returns an indication whether autotuning is enabled for this node. autotune()197 bool autotune() const TF_LOCKS_EXCLUDED(mu_) { 198 return autotune_; 199 } 200 201 // Returns the number of bytes stored in this node's buffer. buffered_bytes()202 int64 buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) { 203 return buffered_bytes_; 204 } 205 206 // Returns the number of elements stored in this node's buffer. buffered_elements()207 int64 buffered_elements() const TF_LOCKS_EXCLUDED(mu_) { 208 return buffered_elements_; 209 } 210 211 // Returns the number of bytes consumed by the node. bytes_consumed()212 int64 bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) { 213 return bytes_consumed_; 214 } 215 216 // Returns the number of bytes produced by the node. bytes_produced()217 int64 bytes_produced() const TF_LOCKS_EXCLUDED(mu_) { 218 return bytes_produced_; 219 } 220 221 // Indicates whether the node has tunable parameters. has_tunable_parameters()222 bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) { 223 tf_shared_lock l(mu_); 224 for (const auto& pair : parameters_) { 225 if (pair.second->state->tunable) return true; 226 } 227 return false; 228 } 229 230 // Returns the unique node ID. id()231 int64 id() const TF_LOCKS_EXCLUDED(mu_) { return id_; } 232 233 // Returns the node inputs. inputs()234 std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) { 235 tf_shared_lock l(mu_); 236 return inputs_; 237 } 238 239 // Returns a longer node name that is guaranteed to be unique. long_name()240 string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); } 241 242 // Returns the node name. name()243 const string& name() const { return name_; } 244 245 // Returns the number of elements produced by the node. num_elements()246 int64 num_elements() const TF_LOCKS_EXCLUDED(mu_) { 247 return num_elements_; 248 } 249 250 // Returns the node output. output()251 Node* output() const { return output_; } 252 253 // Returns the parameter value. parameter_value(const string & name)254 double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) { 255 tf_shared_lock l(mu_); 256 return parameters_.at(name)->state->value; 257 } 258 259 // Returns the aggregate processing time. processing_time()260 int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) { 261 return processing_time_; 262 } 263 264 // Records that the node consumed the given number of bytes. record_bytes_consumed(int64 num_bytes)265 void record_bytes_consumed(int64 num_bytes) { bytes_consumed_ += num_bytes; } 266 267 // Records that the node produced the given number of bytes. record_bytes_produced(int64 num_bytes)268 void record_bytes_produced(int64 num_bytes) { bytes_produced_ += num_bytes; } 269 270 // Records the change in this node's buffer. record_buffer_event(int64 bytes_delta,int64 elements_delta)271 void record_buffer_event(int64 bytes_delta, int64 elements_delta) { 272 buffered_bytes_ += bytes_delta; 273 buffered_elements_ += elements_delta; 274 } 275 276 // Records that the node produced an element. record_element()277 void record_element() TF_LOCKS_EXCLUDED(mu_) { 278 num_elements_++; 279 } 280 281 // Records that a node thread has started executing. record_start(int64 time_nanos)282 void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) { 283 DCHECK_EQ(work_start_, 0); 284 work_start_ = time_nanos; 285 } 286 287 // Records that a node thread has stopped executing. record_stop(int64 time_nanos)288 void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) { 289 // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here. 290 if (work_start_ != 0) { 291 processing_time_ += time_nanos - work_start_; 292 work_start_ = 0; 293 } else { 294 VLOG(1) << "Encountered a stop event without a matching start event."; 295 } 296 } 297 298 // Returns whether work is currently being recorded, i.e. whether we are 299 // currently between a `record_start` and a `record_stop`. is_recording()300 bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; } 301 302 // Removes an input. remove_input(std::shared_ptr<Node> input)303 void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) { 304 mutex_lock l(mu_); 305 inputs_.remove(input); 306 } 307 308 // Sets the value that determines whether autotuning is enabled for this node. set_autotune(bool autotune)309 void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) { 310 autotune_.store(autotune); 311 } 312 313 // Given the average time between events when the elements in the buffer are 314 // produced (`producer_time`), the average time between events when elements 315 // in the buffer are consumed (`consumer_time`) and the buffer size, the 316 // method computes the expected time an consumer event will have to wait. 317 // 318 // The wait time is approximated as the product of the probability the buffer 319 // will be empty and the time it takes to produce an element into the buffer. 320 // 321 // The formula used for computing the probability is derived by modeling the 322 // problem as an M/M/1/K queue 323 // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue). 324 // 325 // Collects derivatives of `ComputeWaitTime` w.r.t `producer_time`, 326 // `consumer_time' and `buffer_size` if the corresponding pointers are not 327 // `nullptr`. 328 static double ComputeWaitTime(const double& producer_time, 329 const double& consumer_time, 330 const double& buffer_size, 331 double* producer_time_derivative, 332 double* consumer_time_derivative, 333 double* buffer_size_derivative); 334 335 // Collects tunable parameters in the subtree rooted in this node. 336 void CollectTunableParameters( 337 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const 338 TF_LOCKS_EXCLUDED(mu_); 339 340 // Returns a human-readable representation of this node. 341 string DebugString() const TF_LOCKS_EXCLUDED(mu_); 342 343 // Flushes the metrics recorded by this node. 344 void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); 345 346 // Returns the per-element output time for this node and if `gradients` is not 347 // `nullptr`, collects the output time gradient w.r.t. tunable parameters of 348 // the subtree rooted in this node. 349 double OutputTime(absl::flat_hash_map<string, double>* input_times, 350 absl::flat_hash_map<string, double>* gradients) const 351 TF_LOCKS_EXCLUDED(mu_); 352 353 // Returns a copy of this node, making a deep copy of its inputs and a 354 // shallow copy of its tunable parameters. 355 // 356 // The purpose for this method is to allow the model optimization logic to 357 // operate over immutable state while allowing concurrent model updates. 358 std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_); 359 360 // Returns the per-element processing time spent in this node. 361 double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_); 362 363 // Returns the total number of bytes buffered in all nodes in the subtree for 364 // which autotuning is enabled. 365 double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); 366 367 // Collects the total buffer limit of all nodes in the subtree for which 368 // autotuning is enabled. This number represents the amount of memory that 369 // would be used by the subtree nodes if all of their buffers were full. 370 double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); 371 372 // Returns the per-element CPU time spent in the subtree rooted in this node. 373 // If `processing_times` is not `nullptr`, collects the per-element CPU time 374 // spent in each node of the subtree. 375 double TotalProcessingTime( 376 absl::flat_hash_map<string, double>* processing_times) 377 TF_LOCKS_EXCLUDED(mu_); 378 379 // Recursively produces a proto for this node and its subtree. 380 virtual Status ToProto(ModelProto::Node* node_proto) const; 381 382 // Recursively restores a node and its subtree from the proto. 383 static Status FromProto(ModelProto::Node node_proto, 384 std::shared_ptr<Node> output, 385 std::shared_ptr<Node>* node); 386 387 protected: 388 // Used for (incrementally) recording metrics. The class is thread-safe. 389 class Metrics { 390 public: Metrics(const string & name)391 explicit Metrics(const string& name) 392 : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)), 393 bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)), 394 num_elements_counter_(metrics::GetTFDataElementsCounter(name)), 395 recorded_bytes_consumed_(0), 396 recorded_bytes_produced_(0), 397 recorded_num_elements_(0) {} 398 399 // Expects the total number of bytes consumed and records the delta since 400 // last invocation. record_bytes_consumed(int64 total_bytes)401 void record_bytes_consumed(int64 total_bytes) { 402 int64 delta = 403 total_bytes - recorded_bytes_consumed_.exchange(total_bytes); 404 bytes_consumed_counter_->IncrementBy(delta); 405 } 406 407 // Expects the total number of bytes produced and records the delta since 408 // last invocation. record_bytes_produced(int64 total_bytes)409 void record_bytes_produced(int64 total_bytes) { 410 int64 delta = 411 total_bytes - recorded_bytes_produced_.exchange(total_bytes); 412 bytes_produced_counter_->IncrementBy(delta); 413 } 414 415 // Expects the total number of elements produced and records the delta since 416 // last invocation. record_num_elements(int64 total_elements)417 void record_num_elements(int64 total_elements) { 418 int64 delta = 419 total_elements - recorded_num_elements_.exchange(total_elements); 420 num_elements_counter_->IncrementBy(delta); 421 } 422 423 private: 424 monitoring::CounterCell* const bytes_consumed_counter_; 425 monitoring::CounterCell* const bytes_produced_counter_; 426 monitoring::CounterCell* const num_elements_counter_; 427 std::atomic<int64> recorded_bytes_consumed_; 428 std::atomic<int64> recorded_bytes_produced_; 429 std::atomic<int64> recorded_num_elements_; 430 }; 431 432 // Returns the number of inputs. num_inputs()433 int64 num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) { 434 int64 num_inputs = 0; 435 for (auto& input : inputs_) { 436 // Inputs for which autotuning is disabled are excluded. 437 if (input->autotune()) { 438 ++num_inputs; 439 } 440 } 441 return num_inputs; 442 } 443 444 // Creates a clone of this node. 445 virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const 446 TF_SHARED_LOCKS_REQUIRED(mu_) = 0; 447 448 // Returns the average size of an element buffered in this node. 449 double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_); 450 451 // Returns the sum of per-element output time for the tunable inputs of this 452 // node. 453 double OutputTimeForInputs( 454 const absl::flat_hash_map<string, double>& output_times) const 455 TF_SHARED_LOCKS_REQUIRED(mu_); 456 457 // Returns the sum of output time gradient w.r.t. input time for the tunable 458 // inputs of this node. 459 double OutputTimeGradientsForInputs( 460 const absl::flat_hash_map<string, double>& output_time_gradients) const 461 TF_SHARED_LOCKS_REQUIRED(mu_); 462 463 // Computes the input time for this node and stores it in `input_times`. 464 virtual void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) 465 const TF_SHARED_LOCKS_REQUIRED(mu_) = 0; 466 467 // Computes the per-element output time for this node and stores it in 468 // `output_times`. If `gradients` is not `nullptr`, computes the output time 469 // gradient w.r.t. tunable parameters of the subtree rooted in this node and 470 // stores it in `gradients`, also computes the output time gradient w.r.t. 471 // input time and stores it in `output_time_gradients`. 472 virtual void OutputTimeLocked( 473 const absl::flat_hash_map<string, double>& input_times, 474 absl::flat_hash_map<string, double>* gradients, 475 absl::flat_hash_map<string, double>* output_times, 476 absl::flat_hash_map<string, double>* output_time_gradients) const 477 TF_SHARED_LOCKS_REQUIRED(mu_) = 0; 478 479 // Returns the sum of per-element processing time for the inputs of this node 480 // by adding values for input nodes in `total_processing_times`. Processing 481 // time for a given input is a weighted combination of a statistic based on 482 // history of input processing time and the actual time. This is done to 483 // improve accuracy of processing time estimation for newly created inputs. 484 // 485 // Uniform distribution of per-element processing times across different 486 // inputs is assumed. 487 double TotalProcessingTimeForInputs( 488 const absl::flat_hash_map<string, double>& total_processing_times) 489 TF_SHARED_LOCKS_REQUIRED(mu_); 490 491 // Returns the per-element processing time spent in this node. 492 double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_); 493 494 // Computes the per-element CPU time spent in the subtree rooted in this node 495 // and stores it in `total_processing_times`. If `processing_times` is not 496 // `nullptr`, collects the per-element CPU time spent in each node of the 497 // subtree. 498 virtual void TotalProcessingTimeLocked( 499 absl::flat_hash_map<string, double>* processing_times, 500 absl::flat_hash_map<string, double>* total_processing_times) 501 TF_SHARED_LOCKS_REQUIRED(mu_) = 0; 502 503 // Returns a vector of nodes of the subtree rooted in this node. The nodes are 504 // either in breadth-first search or reverse breadth-first search order 505 // depending on the `order` argument. The nodes are collected based on the 506 // results of the `collect_node` predicate: if the predicate returns `false` 507 // for a given node, then the subtree rooted in this node is excluded. The 508 // root node itself is not collected. 509 NodeVector CollectNodes(TraversalOrder order, 510 bool collect_node(const std::shared_ptr<Node>)) const 511 TF_SHARED_LOCKS_REQUIRED(mu_); 512 513 // Collect tunable parameters on the nodes which have recorded elements. 514 void CollectTunableParametersHelper( 515 absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const 516 TF_SHARED_LOCKS_REQUIRED(mu_); 517 518 // Build up debug string for the node and store in the debug strings map. 519 void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings) 520 const TF_SHARED_LOCKS_REQUIRED(mu_); 521 522 // Copy the node and add the (input, copy) pairs to the NodePairList. 523 std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output, 524 NodePairList* node_pairs) const; 525 526 // Compute total buffered bytes for the node and store in the total bytes map. 527 void TotalBufferedBytesHelper( 528 absl::flat_hash_map<string, double>* total_bytes) const 529 TF_SHARED_LOCKS_REQUIRED(mu_); 530 531 // Compute total maximum buffered bytes for the node and store in the total 532 // bytes map. 533 void TotalMaximumBufferedBytesHelper( 534 absl::flat_hash_map<string, double>* total_bytes) const 535 TF_SHARED_LOCKS_REQUIRED(mu_); 536 537 // Compute and return the maximum buffered bytes on the node itself. By 538 // default non-tunable nodes are assumed not to buffer any bytes, so the 539 // tunable nodes as subclasses are expected to override this method to ensure 540 // that the optimization algorithm respects the memory budget. 541 virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_); 542 543 // Restores node from the proto. Note that this is not done recursively, i.e. 544 // input nodes are not restored. 545 static Status FromProtoHelper(ModelProto::Node node_proto, 546 std::shared_ptr<Node> node); 547 548 // Stores the time passed to the last call to `Node::record_start()` on the 549 // current thread. 550 // 551 // NOTE: This thread-local variable is shared between all instances of `Node` 552 // on which the same thread calls `record_start()` or `record_stop()`. It 553 // relies on the invariant that at most one `Node` can be "active" on a 554 // particular thread at any time. Therefore if `n->record_start()` is called 555 // on thread `t`, then `n->record_stop()` must be called before another call 556 // to `Node::record_start()` (for any node). 557 static thread_local int64 work_start_; // Will be initialized to zero. 558 559 mutable mutex mu_; 560 const int64 id_; 561 const string name_; 562 563 // Indicates whether the subtree rooted in this node should be included in 564 // autotuning. In particular, if this is `false`, then the subtree is excluded 565 // from computation of output time and processing time. 566 std::atomic<bool> autotune_; 567 std::atomic<int64> buffered_bytes_; 568 std::atomic<int64> buffered_elements_; 569 std::atomic<int64> bytes_consumed_; 570 std::atomic<int64> bytes_produced_; 571 std::atomic<int64> num_elements_; 572 std::atomic<int64> processing_time_; 573 std::atomic<bool> record_metrics_; 574 Metrics metrics_; 575 absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_ 576 TF_GUARDED_BY(mu_); 577 578 // Statistic of inputs processing time history. 579 double input_processing_time_sum_ = 0.0L; 580 int64 input_processing_time_count_ = 0; 581 582 // Inputs of this node. These can represent an iterator created from the input 583 // dataset but also other input iterators (e.g. created by the user-defined 584 // functions of `flat_map` or `interleave`). 585 std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_); 586 587 // The reference to the output node is not owned so that deletion of a 588 // node results in recursive deletion of the subtree rooted in the node. 589 Node* const output_; 590 }; 591 592 // InterleaveMany is used to model datasets whose inputs are used to create 593 // datasets whose elements are then interleaved. 594 std::shared_ptr<Node> MakeInterleaveManyNode(Node::Args args); 595 596 // AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany 597 // nodes. 598 std::shared_ptr<Node> MakeAsyncInterleaveManyNode( 599 Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters); 600 601 // KnownMany nodes model datasets that synchronously consume known number of 602 // input element per output element. 603 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio); 604 605 // AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes. 606 std::shared_ptr<Node> MakeAsyncKnownRatioNode( 607 Node::Args args, double ratio, double memory_ratio, 608 std::vector<std::shared_ptr<Parameter>> parameters); 609 610 std::shared_ptr<Node> MakeAsyncKnownRatioNode( 611 Node::Args args, double ratio, 612 std::vector<std::shared_ptr<Parameter>> parameters); 613 614 // Source nodes represent data sources. 615 std::shared_ptr<Node> MakeSourceNode(Node::Args args); 616 617 // UnknownMany nodes represent datasets that synchronously consume an 618 // unknown number of input elements per output. 619 // 620 // Unlike KnownRatio nodes which expect the ratio between inputs and outputs is 621 // specified as a parameter, UnknownRatio estimates the ratio empirically. 622 std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args); 623 624 // Unknown nodes represent datasets for which we do not have a model. It acts 625 // as pass-through between inputs and output. 626 std::shared_ptr<Node> MakeUnknownNode(Node::Args args); 627 628 // Abstract representation of a TensorFlow input pipeline that can be used 629 // for collecting runtime information and optimizing performance. It collects 630 // runtime information about execution of the input pipeline that is used to 631 // create a performance model, which is in turn used to identify optimal values 632 // of tunable parameters. 633 // 634 // Developers of tf.data transformations are not expected to interact with this 635 // class directly. Boiler plate code for creating the abstract representation of 636 // the input pipeline and collecting runtime information has been added to the 637 // implementation of `DatasetBase` and `DatasetBaseIterator` respectively. 638 class Model { 639 public: 640 using OptimizationParams = ModelProto::OptimizationParams; 641 642 // Creates a new model. Model()643 Model() 644 : collect_resource_usage_(false), 645 optimization_period_ms_(kOptimizationPeriodMinMs) { 646 const char* save_dir = std::getenv("TF_DATA_AUTOTUNE_DEBUG_DIR"); 647 if (save_dir) { 648 save_dir_ = string(save_dir); 649 } 650 } 651 ~Model()652 ~Model() { 653 if (!save_dir_.empty()) { 654 save_thread_cancelled_ = true; 655 save_cond_var_.notify_all(); 656 } 657 } 658 659 // Indicates whether to collect resource usage. collect_resource_usage()660 bool collect_resource_usage() const { return collect_resource_usage_; } 661 662 // Returns a pointer to the model's output node. output()663 const std::shared_ptr<Node> output() { 664 mutex_lock l(mu_); 665 return output_; 666 } 667 668 // Adds a node with the given name and given parent. 669 void AddNode(Node::Factory factory, const string& name, 670 std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node) 671 TF_LOCKS_EXCLUDED(mu_); 672 673 // Uses the given algorithm and resource budgets to periodically perform the 674 // autotuning optimization. 675 // 676 // To terminate the execution of the optimization loop, the caller needs to 677 // invoke `cancellation_mgr->StartCancel()`. 678 Status OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget, 679 int64 ram_budget, CancellationManager* cancellation_mgr); 680 681 // Uses the given algorithm and resource budgets to perform the autotuning 682 // optimization. 683 void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget, 684 double model_input_time); 685 686 // Removes the given node. 687 void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_); 688 689 // Produces a proto for this model. 690 Status ToProto(ModelProto* model_proto); 691 692 // Restores a model from the proto. 693 static Status FromProto(ModelProto model_proto, 694 std::unique_ptr<Model>* model); 695 696 // Saves this model with a given snapshot and its optimization parameters to a 697 // file. Note that the file directory must already exist. 698 Status Save(const string& fname, std::shared_ptr<Node> snapshot, 699 const OptimizationParams& optimization_params); 700 701 // Loads a model and its optimization parameters from a file with the given 702 // name. 703 static Status Load(const string& fname, std::unique_ptr<Model>* model, 704 OptimizationParams* optimization_params); 705 706 private: 707 static constexpr int64 kOptimizationPeriodMinMs = 10; 708 static constexpr int64 kOptimizationPeriodMaxMs = 709 60 * EnvTime::kSecondsToMillis; 710 711 // Maximum number of optimization snapshots kept in a buffer for saving. 712 static constexpr int64 kMaxNumBufferedOptimizeArgs = 100; 713 714 // Collects tunable parameters in the tree rooted in the given node, returning 715 // a mapping from a (unique) node name to a tunable parameter. 716 absl::flat_hash_map<string, std::shared_ptr<Parameter>> 717 CollectTunableParameters(std::shared_ptr<Node> node); 718 719 // Flushes metrics recorded by the model. 720 void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); 721 722 // This optimization algorithm starts by setting all tunable parallelism 723 // parameters to the minimum value. It then repeatedly identifies the 724 // parameter whose increase in parallelism decreases the output time the most. 725 // This process is repeated until all parameters reach their maximum values or 726 // the projected output time is less than or equal to the processing time 727 // needed to produce an element divided by CPU budget. 728 void OptimizeHillClimb(std::shared_ptr<Node> snapshot, 729 const OptimizationParams& optimization_params); 730 731 // This optimization algorithm starts by setting all tunable parallelism 732 // parameters to the minimum value. It then improves current parameters by 733 // making a step in the direction opposite to the gradient of `OutputTime` and 734 // projecting resulting values on the feasible intervals. Improvement step is 735 // repeated until either the output time improvement is smaller than threshold 736 // value or the output time is less than the processing time needed to produce 737 // an element divided by CPU budget. 738 void OptimizeGradientDescent(std::shared_ptr<Node> snapshot, 739 const OptimizationParams& optimization_params); 740 741 // Collects the output time and if `gradients` is not `nullptr`, the output 742 // time gradient w.r.t. tunable parameters of the subtree rooted in the given 743 // node. 744 double OutputTime(std::shared_ptr<Node> node, double model_input_time, 745 absl::flat_hash_map<string, double>* gradients); 746 747 // Determines if we should stop the gradient descent optimization iterations 748 // based on number of increasable parameters, CPU budget, RAM budget and 749 // current resource usage. 750 bool ShouldStop( 751 int64 cpu_budget, int64 ram_budget, 752 const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters, 753 const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& 754 parallelism_parameters, 755 const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& 756 buffer_size_parameters, 757 std::shared_ptr<Node> snapshot, bool* cpu_budget_reached); 758 759 // Collects the processing time for the given node. 760 double TotalProcessingTime(std::shared_ptr<Node> node); 761 762 // Collects the total number of bytes buffered in all nodes in the subtree 763 // rooted in the given node for which autotuning is enabled. 764 double TotalBufferedBytes(std::shared_ptr<Node> node); 765 766 // Collects the total buffer limit of all nodes in the subtree rooted in the 767 // given node for which autotuning is enabled. This number represents the 768 // amount of memory that would be used by the subtree nodes if all of their 769 // buffers were full. 770 double TotalMaximumBufferedBytes(std::shared_ptr<Node> node); 771 772 // Starts a model saving thread if it hasn't started yet. 773 Status EnsureSaveLoopThreadStarted(); 774 775 // Periodically saves the state of optimization that is kept in 776 // `save_buffer_`. 777 // 778 // The saving loop is terminated when the model is destroyed. 779 Status SaveLoop(); 780 781 // Used for coordination between different input pipeline threads. Exclusive 782 // access is required only when adding or removing nodes. Concurrent access to 783 // existing nodes is protected by a node mutex. 784 mutex mu_; 785 // Used for coordinating the optimization loop and model modifications. 786 condition_variable optimize_cond_var_; 787 int64 id_counter_ TF_GUARDED_BY(mu_) = 1; 788 std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_); 789 790 // Indicates whether the modeling framework should collect resource usage 791 // (e.g. CPU, memory). The logic for collecting this information assumes that 792 // the collection is not repeatedly disabled and enabled. As a consequence, 793 // the implementation starts collecting resource usage when it encounters a 794 // tunable parameter (because the information is used for tuning the value of 795 // the parameter) and never stops. 796 std::atomic<bool> collect_resource_usage_; 797 798 // Determines the time the optimization loop should wait between 799 // running optimizations. 800 int64 optimization_period_ms_ TF_GUARDED_BY(mu_); 801 802 // Thread that runs the model saving loop. 803 std::unique_ptr<Thread> save_thread_ TF_GUARDED_BY(mu_); 804 805 // Used for coordinating the saving loop and model optimization. 806 condition_variable save_cond_var_; 807 808 // Indicates whether the save thread is cancelled. 809 bool save_thread_cancelled_ = false; 810 811 // Contains path to the model saving directory if saving is enabled, empty 812 // otherwise. 813 string save_dir_; 814 815 // Contains pairs of model snapshots and optimization parameters to be saved 816 // if model saving is enabled, empty otherwise. Buffer elements are pushed by 817 // `OptimizeLoop` and popped by `SaveLoop`. 818 std::deque<std::pair<std::shared_ptr<Node>, OptimizationParams>> save_buffer_ 819 TF_GUARDED_BY(mu_); 820 }; 821 822 } // namespace model 823 } // namespace data 824 } // namespace tensorflow 825 826 #endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ 827