/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g.,
//   GraphMgr gmgr(worker_env);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
//                             &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph. Fills in "handle". The registered graph retains a
  // reference to cluster_flr to do cross-process function calls.
  Status Register(const string& session, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* handle);
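
  // A minimal registration sketch (illustrative only; "gdef" is a GraphDef
  // built elsewhere, and "cluster_flr" is assumed to be provided by the
  // surrounding worker session):
  //
  //   string handle;
  //   TF_CHECK_OK(gmgr.Register("session_0", gdef, GraphOptions(),
  //                             DebugOptions(), /*collective_graph_key=*/0,
  //                             cluster_flr, &handle));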

  // Executes one step of the registered graph "handle".
  //
  // Tensors in "in" are fed to the executors; outputs produced by the
  // step are retrieved afterwards via RecvOutputs()/RecvOutputsAsync()
  // using the same "step_id". "done" is invoked when the step finishes
  // or fails.
  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;
  void ExecuteAsync(const string& handle, const int64 step_id,
                    WorkerSession* session, const ExecutorOpts& opts,
                    StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    const NamedTensors& in, StatusCallback done);
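
  // A minimal calling sketch (illustrative only: "resp", "cm", "opts", and
  // "session" stand in for caller-owned objects, and error handling is
  // elided):
  //
  //   Notification n;
  //   gmgr.ExecuteAsync(handle, step_id, session, opts,
  //                     /*collector=*/nullptr, resp, cm, in,
  //                     [&n](const Status& s) {
  //                       TF_CHECK_OK(s);
  //                       n.Notify();
  //                     });
  //   n.WaitForNotification();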

  // Sends the tensors in "in" to the rendezvous for step "step_id".
  Status SendInputs(const int64 step_id, const NamedTensors& in);
  // Receives the tensors named by the keys of "*out" for step "step_id",
  // blocking until they are available.
  Status RecvOutputs(const int64 step_id, NamedTensors* out);
  // As above, but invokes "done" instead of blocking.
  void RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                        StatusCallback done);
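
  // For example (sketch; assumes the registered graph feeds "a" and
  // fetches "c"):
  //
  //   GraphMgr::NamedTensors feeds = { { "a", Tensor({1, 2}) } };
  //   TF_CHECK_OK(gmgr.SendInputs(step_id, feeds));
  //   GraphMgr::NamedTensors fetches = { { "c", Tensor() } };
  //   TF_CHECK_OK(gmgr.RecvOutputs(step_id, &fetches));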

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    Graph* graph = nullptr;                 // not owned.
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64 build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keep a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id runs only once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions,
    // one per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices.  Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Used to deregister the cost model when the graph manager requires
    // cost models.
    GraphMgr* graph_mgr;

    int64 collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64 next_id_ GUARDED_BY(mu_) = 0;

  // If true, blocks until device has finished all queued operations in a step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // lose memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
                              Rendezvous* rendezvous,
                              CollectiveExecutor::Handle* ce_handle,
                              StepStatsCollector* collector,
                              CostGraphDef* cost_graph,
                              CancellationManager* cancellation_manager,
                              StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& session, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options, int64 collective_graph_key,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_