1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ 17 18 #include <algorithm> 19 #include <cstddef> 20 #include <map> 21 #include <memory> 22 #include <queue> 23 #include <string> 24 #include <thread> 25 #include <vector> 26 27 #include "tensorflow/core/common_runtime/device_factory.h" 28 #include "tensorflow/core/common_runtime/function.h" 29 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 30 #include "tensorflow/core/framework/rendezvous.h" 31 #include "tensorflow/core/lib/gtl/inlined_vector.h" 32 #include "tensorflow/core/lib/gtl/map_util.h" 33 #include "tensorflow/core/lib/gtl/stl_util.h" 34 #include "tensorflow/core/platform/mutex.h" 35 #include "tensorflow/core/platform/thread_annotations.h" 36 #include "tensorflow/core/public/version.h" 37 38 namespace tensorflow { 39 40 // A unit of execution for the EagerExecutor class below. Example subclasses 41 // encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one 42 // device to another. 43 class EagerNode { 44 public: 45 explicit EagerNode(uint64 id); 46 ~EagerNode()47 virtual ~EagerNode() {} 48 49 // Runs the computation corresponding to this node and blocks till the 50 // execution is done. 51 virtual Status Run() = 0; 52 53 // An id unique to the TFE_Context under which this node is created. Allocated 54 // monotonically. 55 const uint64 id; 56 }; 57 58 // A class for handling async execution (see TFE_ContextSetAsync). 59 // Note that this class is thread-safe. 60 // TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the 61 // device of the input handle. Fix that. 62 // TODO(agarwal): On error, mark all affected handles as corrupted. 63 // TODO(agarwal): Implement support for control dependencies. 64 // TODO(agarwal): Support out-of-order execution and dispatching multiple 65 // EagerNode in parallel. 66 // TODO(agarwal): Implement optimizations over EagerNode traces. 67 class EagerExecutor { 68 public: 69 ~EagerExecutor(); 70 71 // This is called whenever async mode is enabled. Note that it may be called 72 // multiple times as different calling threads may switch async mode on or off 73 // independently. 74 void EnableAsync(); 75 76 // Helper function to create monotonically increasing ids unique to this 77 // object. 78 uint64 NextId(); 79 80 // Schedules `node` for execution. 81 // Note that Add must be called in monotonically increasing order of node->id. 82 void Add(EagerNode* node); 83 84 // Causes the caller to block till node with id `node_id` has finished 85 // execution. 86 Status WaitFor(uint64 node_id); 87 88 // Blocks till all currently pending ops are done. 89 Status WaitForAllPendingNodes(); 90 91 // Clears all currently set errors which re-enables async execution. 92 void ClearError(); 93 94 // Returns Status based on any errors that occurred during async execution. 95 Status status(); 96 97 private: 98 // Starts execution of pending EagerNodes. This function loops till 99 // thread_done_ is set to true. If any errors are encontered, these are set 100 // inside `status_`. The loop blocks anytime there are no pending nodes, or if 101 // `status_` is not ok. 102 void Run(); 103 104 Status WaitImpl(bool wait_all, uint64 node_id); 105 106 mutex node_queue_mutex_; 107 108 // Used to signal that some EagerNodes are pending execution. 109 condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_); 110 111 // Queue of pending EagerNodes. 112 std::queue<EagerNode*> node_queue_ GUARDED_BY(node_queue_mutex_); 113 114 // `status_` is set based on any errors raised during execution of a 115 // EagerNode. It remains set until ClearError is called. 116 Status status_ GUARDED_BY(node_queue_mutex_); 117 118 // Map from id of a EagerNode to condition_variables (not owned by the map). 119 // These condition_variables are notified and removed when that EagerNode is 120 // done executing, or if an error is found in execution of any EagerNode. 121 std::multimap<uint64, condition_variable*> node_done_notifications_ 122 GUARDED_BY(node_queue_mutex_); 123 124 // Thread object that calls the `Run` method. Currently we use only one thread 125 // for executing the EagerNodes one-by-one. 126 std::unique_ptr<Thread> thread_ GUARDED_BY(node_queue_mutex_); 127 128 // Indicates that `thread_` should stop as soon as it is done executing the 129 // current EagerNode. 130 bool thread_done_ GUARDED_BY(node_queue_mutex_) = false; 131 132 mutex next_id_mutex_; 133 uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1; 134 }; 135 136 } // namespace tensorflow 137 138 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ 139