/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
class WorkerSession;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g. (schematic; see the method comments below for the full signatures):
//   GraphMgr gmgr(worker_env, device_mgr);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
//                             &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph. Fills in "graph_handle". The registered graph retains a
  // reference to cluster_flr to do cross-process function calls.
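  //
  // A usage sketch (illustrative only; "gdef", "session", and the option
  // protos are assumed to be constructed by the caller):
  //
  //   string graph_handle;
  //   TF_CHECK_OK(gmgr.Register("session", gdef, session, graph_options,
  //                             debug_options, config_proto,
  //                             /*collective_graph_key=*/0, cluster_flr,
  //                             &graph_handle));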
  Status Register(const string& handle, const GraphDef& gdef,
                  WorkerSession* session, const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* graph_handle);

  // Executes one step of a registered graph "handle".
  //
  // Inputs are supplied in "in"; outputs are fetched separately via
  // RecvOutputs/RecvOutputsAsync below, where a non-null "out" specifies
  // all keys the execution should produce upon finish.
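  //
  // A sketch of an asynchronous run (illustrative only; "opts", "collector",
  // "response", and "cancellation_manager" are assumed to be set up and owned
  // by the caller, and "done" may run on another thread):
  //
  //   gmgr.ExecuteAsync(handle, step_id, session, opts, collector, response,
  //                     cancellation_manager, in,
  //                     [](const Status& s) { TF_CHECK_OK(s); });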
  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;
  void ExecuteAsync(const string& handle, const int64 step_id,
                    WorkerSession* session, const ExecutorOpts& opts,
                    StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    const NamedTensors& in, StatusCallback done);

  // Sends the input tensors "in" to the rendezvous for step "step_id".
  Status SendInputs(const int64 step_id, const NamedTensors& in);

  // Receives the output tensors for step "step_id", blocking until all keys
  // in "out" are available.
  Status RecvOutputs(const int64 step_id, NamedTensors* out);

  // As RecvOutputs, but invokes "done" once the outputs have been received
  // (or the receive has failed) instead of blocking.
  void RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                        StatusCallback done);
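
  // A usage sketch (illustrative only; "step_id" and "in" follow the class
  // comment above; RecvOutputs blocks until the step's outputs arrive):
  //
  //   TF_CHECK_OK(gmgr.SendInputs(step_id, in));
  //   GraphMgr::NamedTensors out = { { "c", Tensor() } };
  //   TF_CHECK_OK(gmgr.RecvOutputs(step_id, &out));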

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64 build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keep a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id is only run once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions, one
    // per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices.  Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Back-pointer used to deregister this item's cost model when the graph
    // manager is tracking cost models.
    GraphMgr* graph_mgr;

    int64 collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  const DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64 next_id_ TF_GUARDED_BY(mu_) = 0;

  // If true, blocks until the device has finished all queued operations in a
  // step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // leak memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
                              Rendezvous* rendezvous,
                              CollectiveExecutor::Handle* ce_handle,
                              StepStatsCollector* collector,
                              CostGraphDef* cost_graph,
                              CancellationManager* cancellation_manager,
                              WorkerSession* session, StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& handle, const GraphDef& gdef,
                  WorkerSession* session, const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_