/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
class WorkerSession;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g. (schematic; some arguments are elided, and "Execute" stands for
// the ExecuteAsync/RecvOutputs sequence declared below):
//   GraphMgr gmgr(worker_env, device_mgr);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
//                             ..., &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph for the session "handle" and fills in "*graph_handle".
  // The registered graph retains a reference to cluster_flr to do
  // cross-process function calls.
  Status Register(const string& handle, const GraphDef& gdef,
                  WorkerSession* session, const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* graph_handle);
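
  // Illustrative only: a minimal registration sketch, assuming a hypothetical
  // caller that already holds "session_handle", "gdef", "session", and
  // "cluster_flr". Options protos are default-constructed, and a
  // collective_graph_key of 0 is assumed here to mean no collective ops.
  //
  //   string graph_handle;
  //   TF_RETURN_IF_ERROR(graph_mgr->Register(
  //       session_handle, gdef, session, GraphOptions(), DebugOptions(),
  //       ConfigProto(), 0 /* collective_graph_key */, cluster_flr,
  //       &graph_handle));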

  // Executes one step of a registered graph "handle".
  //
  // "in" supplies the input tensors, keyed by their rendezvous names.
  // "done" is invoked with the step's status once execution finishes.
  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;
  void ExecuteAsync(const string& handle, const int64 step_id,
                    WorkerSession* session, const ExecutorOpts& opts,
                    StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    const NamedTensors& in, StatusCallback done);

  // Sends the tensors "in" to the rendezvous of step "step_id".
  Status SendInputs(const int64 step_id, const NamedTensors& in);

  // Receives the tensors for step "step_id". "*out" specifies all keys the
  // execution should deliver upon finish; blocks until they are available.
  Status RecvOutputs(const int64 step_id, NamedTensors* out);

  // Asynchronous version of RecvOutputs(). Invokes "done" once all
  // requested tensors have been received.
  void RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                        StatusCallback done);
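
  // Illustrative only: a minimal sketch of driving one step to completion,
  // assuming "handle" came from Register() and that "session", "response",
  // and "cancellation_manager" are hypothetical objects owned by the caller.
  // ExecuteAsync() reports completion through its StatusCallback, so a
  // blocking caller can wait on a Notification and then fetch the outputs:
  //
  //   GraphMgr::NamedTensors in = ...;   // feeds, keyed by rendezvous name.
  //   GraphMgr::NamedTensors out = ...;  // keys to fetch; values filled in.
  //   Notification n;
  //   Status status;
  //   gmgr.ExecuteAsync(handle, step_id, session, ExecutorOpts(),
  //                     nullptr /* collector */, response,
  //                     cancellation_manager, in, [&](const Status& s) {
  //                       status = s;
  //                       n.Notify();
  //                     });
  //   n.WaitForNotification();
  //   TF_RETURN_IF_ERROR(status);
  //   TF_RETURN_IF_ERROR(gmgr.RecvOutputs(step_id, &out));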

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64 build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keeps a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id runs only once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions,
    // one per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices. Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Back-pointer to the owning GraphMgr, used to deregister this item's
    // cost models when cost models are enabled.
    GraphMgr* graph_mgr;

    int64 collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  const DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64 next_id_ TF_GUARDED_BY(mu_) = 0;

  // If true, blocks until the device has finished all queued operations in a
  // step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // lose memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
                              Rendezvous* rendezvous,
                              CollectiveExecutor::Handle* ce_handle,
                              StepStatsCollector* collector,
                              CostGraphDef* cost_graph,
                              CancellationManager* cancellation_manager,
                              WorkerSession* session, StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& handle, const GraphDef& gdef,
                  WorkerSession* session, const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_