/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <chrono>  // NOLINT(build/c++11)
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

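// Rough lifecycle sketch, as typically driven by the worker implementation
// (the authoritative contract is in graph_mgr.h; argument names below match
// the parameters of the methods defined in this file):
//
//   string handle;
//   TF_RETURN_IF_ERROR(graph_mgr->Register(session, gdef, graph_options,
//                                          debug_options, collective_graph_key,
//                                          cluster_flr, &handle));
//   graph_mgr->ExecuteAsync(handle, step_id, worker_session, opts, collector,
//                           response, cancellation_manager, in, done);
//   ...
//   TF_RETURN_IF_ERROR(graph_mgr->Deregister(handle));
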
GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}

GraphMgr::~GraphMgr() {
  for (auto p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}

// NOTE: node->device_name() is not set by GraphConstructor. We
// expect that the NodeDefs in the GraphDef given to workers fully
// specify device names.
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     FormatNodeDefForError(ndef));
    }
  }
  return Status::OK();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// On success, "executors" is filled with one executor per device, and
// the caller takes ownership of the returned executors.
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          int64 collective_graph_key,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          Item* item) {
  item->session = session;
  item->collective_graph_key = collective_graph_key;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
  }

  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, gdef.versions().producer(),
      item->lib_def.get(), graph_options.optimizer_options(),
      worker_env_->compute_pool, cluster_flr));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = &graph.flib_def();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

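  // Build one executor per partition. `params` is reused across iterations of
  // the loop below; the device- and function-library-specific fields are
  // overwritten for each subgraph.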
  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(session);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }

    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
                                                 OpKernel** kernel) {
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached. On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
        delete kernel;
      }
    };

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = subgraph.get();
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, std::move(subgraph), &unit->root));
  }
  return Status::OK();
}

Status GraphMgr::Register(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          int64 collective_graph_key,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          string* handle) {
  Item* item = new Item;
  Status s = InitItem(session, gdef, graph_options, debug_options,
                      collective_graph_key, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
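    // Handles are zero-padded 16-digit hex strings, unique across graphs
    // registered with this GraphMgr.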
    *handle = strings::Printf("%016llx", ++next_id_);
    item->handle = *handle;
    CHECK(table_.insert({*handle, item}).second);
  }
  return Status::OK();
}

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return Status::OK();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return Status::OK();
}

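// SendInputs()/RecvOutputs() move tensors into and out of the per-step
// rendezvous identified by `step_id`. RendezvousMgr::Find() hands back a new
// reference, so each of these paths Unrefs the rendezvous when it is done.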
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
  }
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}

Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  if (!s.ok()) {
    // Failing to fetch the outputs should not be possible, so rewrite the error
    // status to an INTERNAL error.
    s = errors::Internal("Failed to fetch outputs for step ", step_id,
                         ". (Original error message: ", s.ToString(), ")");
  }
  return s;
}

void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
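  // Despite the name, `received_keys` holds the received tensors; entry i
  // corresponds to keys[i] and is copied back into *out below.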
  std::vector<Tensor>* received_keys = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_keys->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_keys->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_keys,
      [done, rendezvous, received_keys, out, keys](const Status s) {
        rendezvous->Unref();
        for (int i = 0; i < keys.size(); ++i) {
          (*out)[keys[i]] = (*received_keys)[i];
        }
        delete received_keys;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);
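  // A collective executor is only acquired for steps whose graph was
  // registered with a real collective graph key.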
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
  // Sends values specified by the caller.
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager,
      [item, rendezvous, ce_handle, done, start_time_usecs](const Status& s) {
        done(s);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}

void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
                                      Item* item, Rendezvous* rendezvous,
                                      CollectiveExecutor::Handle* ce_handle,
                                      StepStatsCollector* collector,
                                      CostGraphDef* cost_graph,
                                      CancellationManager* cancellation_manager,
                                      StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
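  // `step_container` owns the per-step resource-container cleanup registered
  // above and is deleted in the barrier's completion callback below.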
  // NOTE: Transfer one ref of rendezvous and item.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  args.step_id = step_id;
  args.rendezvous = rendezvous;
  args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // Line below is equivalent to this code, but does one less indirect call:
  //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to assign
    // fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph;
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow