• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
17 
18 #include <chrono>  // NOLINT(build/c++11)
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/build_graph_options.h"
22 #include "tensorflow/core/common_runtime/constant_folding.h"
23 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_mgr.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/common_runtime/graph_optimizer.h"
28 #include "tensorflow/core/common_runtime/memory_types.h"
29 #include "tensorflow/core/common_runtime/metrics.h"
30 #include "tensorflow/core/common_runtime/optimization_registry.h"
31 #include "tensorflow/core/common_runtime/process_util.h"
32 #include "tensorflow/core/common_runtime/rendezvous_util.h"
33 #include "tensorflow/core/common_runtime/step_stats_collector.h"
34 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
35 #include "tensorflow/core/framework/cancellation.h"
36 #include "tensorflow/core/framework/collective.h"
37 #include "tensorflow/core/framework/log_memory.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/versions.pb.h"
41 #include "tensorflow/core/graph/graph.h"
42 #include "tensorflow/core/graph/graph_constructor.h"
43 #include "tensorflow/core/graph/graph_partition.h"
44 #include "tensorflow/core/graph/validate.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/strings/stringprintf.h"
47 #include "tensorflow/core/platform/env.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/types.h"
51 #include "tensorflow/core/protobuf/worker.pb.h"
52 #include "tensorflow/core/util/env_var.h"
53 
54 namespace tensorflow {
55 
GraphMgr(const WorkerEnv * worker_env,DeviceMgr * device_mgr)56 GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
57     : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
58   // The default value of sync_on_finish will be flipped soon and this
59   // environment variable will be removed as well.
60   Status status =
61       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
62   if (!status.ok()) {
63     LOG(ERROR) << status.error_message();
64   }
65 }
66 
~GraphMgr()67 GraphMgr::~GraphMgr() {
68   for (auto p : table_) p.second->Unref();
69 }
70 
~Item()71 GraphMgr::Item::~Item() {
72   for (const auto& unit : this->units) {
73     CHECK_NOTNULL(unit.device);
74     if (!graph_mgr->skip_cost_models_) {
75       graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
76     }
77     delete unit.root;
78     unit.device->op_segment()->RemoveHold(this->session);
79   }
80 }
81 
82 // NOTE: node->device_name() is not set by GraphConstructor.  We
83 // expects that NodeDef in GraphDef given to workers fully specifies
84 // device names.
SplitByDevice(const Node * node)85 static string SplitByDevice(const Node* node) {
86   return node->assigned_device_name();
87 }
88 
89 // Validates "gdef" device specifications.
ValidateGraphDefForDevices(const GraphDef & gdef)90 static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
91   DeviceNameUtils::ParsedName parsed;
92   for (const auto& ndef : gdef.node()) {
93     if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
94       return errors::InvalidArgument("Missing device name in: ",
95                                      FormatNodeDefForError(ndef));
96     }
97   }
98   return Status::OK();
99 }
100 
DecorateAndPublishGraphForDebug(const DebugOptions & debug_options,Graph * graph,Device * device)101 Status GraphMgr::DecorateAndPublishGraphForDebug(
102     const DebugOptions& debug_options, Graph* graph, Device* device) {
103   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
104   TF_RETURN_IF_ERROR(
105       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
106   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
107   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
108   return Status::OK();
109 }
110 
111 // Creates executors given a graph definition "gdef" of a "session".
112 // If a node in "gdef" is shared by other graphs in "session", the
113 // same op kernel is reused. E.g., typically a params node is shared
114 // by multiple graphs in a session.
115 //
116 // If "gdef" is assigned to multiple devices, extra nodes (e.g.,
117 // send/recv nodes) maybe added. The extra nodes' name are generated
118 // by calling "new_name(old_name)".
119 //
120 // "executors" are filled with one executor per device if success and
121 // the caller takes the ownership of returned executors.
InitItem(const string & session,const GraphDef & gdef,const GraphOptions & graph_options,const DebugOptions & debug_options,int64 collective_graph_key,DistributedFunctionLibraryRuntime * cluster_flr,Item * item)122 Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
123                           const GraphOptions& graph_options,
124                           const DebugOptions& debug_options,
125                           int64 collective_graph_key,
126                           DistributedFunctionLibraryRuntime* cluster_flr,
127                           Item* item) {
128   item->session = session;
129   item->collective_graph_key = collective_graph_key;
130   item->lib_def.reset(
131       new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
132 
133   TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
134 
135   if (gdef.versions().producer() >= 5) {
136     // Validate the graph: we assume that merging two valid graphs
137     // should maintain graph validity.
138     TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
139   }
140 
141   item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
142       device_mgr_, worker_env_->env, gdef.versions().producer(),
143       item->lib_def.get(), graph_options.optimizer_options(),
144       worker_env_->compute_pool, cluster_flr));
145 
146   // Constructs the graph out of "gdef".
147   Graph graph(OpRegistry::Global());
148   GraphConstructorOptions opts;
149   opts.allow_internal_ops = true;
150   opts.expect_device_spec = true;
151   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
152 
153   // Splits "graph" into multiple subgraphs by device names.
154   std::unordered_map<string, GraphDef> partitions;
155   PartitionOptions popts;
156   popts.node_to_loc = SplitByDevice;
157   popts.new_name = [this](const string& prefix) {
158     mutex_lock l(mu_);
159     return strings::StrCat(prefix, "_G", next_id_++);
160   };
161   popts.get_incarnation = [this](const string& name) -> int64 {
162     Device* device = nullptr;
163     Status s = device_mgr_->LookupDevice(name, &device);
164     if (s.ok()) {
165       return device->attributes().incarnation();
166     } else {
167       return PartitionOptions::kIllegalIncarnation;
168     }
169   };
170   popts.flib_def = &graph.flib_def();
171   popts.control_flow_added = true;
172   popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
173   TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
174   if (popts.scheduling_for_recvs) {
175     TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
176   }
177 
178   std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
179   for (const auto& partition : partitions) {
180     std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
181     GraphConstructorOptions device_opts;
182     // There are internal operations (e.g., send/recv) that we now allow.
183     device_opts.allow_internal_ops = true;
184     device_opts.expect_device_spec = true;
185     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
186                                               device_graph.get()));
187     partition_graphs.emplace(partition.first, std::move(device_graph));
188   }
189 
190   GraphOptimizationPassOptions optimization_options;
191   optimization_options.flib_def = item->lib_def.get();
192   optimization_options.partition_graphs = &partition_graphs;
193   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
194       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
195 
196   LocalExecutorParams params;
197 
198   item->units.reserve(partitions.size());
199   item->graph_mgr = this;
200   const auto& optimizer_opts = graph_options.optimizer_options();
201   GraphOptimizer optimizer(optimizer_opts);
202   for (auto& p : partition_graphs) {
203     const string& device_name = p.first;
204     std::unique_ptr<Graph>& subgraph = p.second;
205     item->units.resize(item->units.size() + 1);
206     ExecutionUnit* unit = &(item->units.back());
207 
208     // Find the device.
209     Status s = device_mgr_->LookupDevice(device_name, &unit->device);
210     if (!s.ok()) {
211       // Remove the empty unit from the item as the item destructor wants all
212       // units to have valid devices.
213       item->units.pop_back();
214       return s;
215     }
216 
217     // Give the device an opportunity to rewrite its subgraph.
218     TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));
219 
220     // Top-level nodes in the graph uses the op segment to cache
221     // kernels. Therefore, as long as the executor is alive, we need
222     // to ensure the kernels cached for the session are alive.
223     auto opseg = unit->device->op_segment();
224     opseg->AddHold(session);
225 
226     // Function library runtime.
227     FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
228     if (lib == nullptr) {
229       return errors::InvalidArgument("Cannot find FLR for device: ",
230                                      unit->device->name());
231     }
232 
233     // Construct the root executor for the subgraph.
234     params.device = unit->device;
235     params.function_library = lib;
236     params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
237                                                  OpKernel** kernel) {
238       // NOTE(mrry): We must not share function kernels (implemented
239       // using `CallOp`) between subgraphs, because `CallOp::handle_`
240       // is tied to a particular subgraph. Even if the function itself
241       // is stateful, the `CallOp` that invokes it is not.
242       if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
243         return lib->CreateKernel(ndef, kernel);
244       }
245       auto create_fn = [lib, &ndef](OpKernel** kernel) {
246         return lib->CreateKernel(ndef, kernel);
247       };
248       // Kernels created for subgraph nodes need to be cached.  On
249       // cache miss, create_fn() is invoked to create a kernel based
250       // on the function library here + global op registry.
251       return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
252     };
253     params.delete_kernel = [lib](OpKernel* kernel) {
254       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
255         delete kernel;
256       }
257     };
258 
259     optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
260                        /*shape_map=*/nullptr);
261 
262     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
263     if (!debug_options.debug_tensor_watch_opts().empty()) {
264       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
265           debug_options, subgraph.get(), params.device));
266     }
267 
268     TF_RETURN_IF_ERROR(
269         EnsureMemoryTypes(DeviceType(unit->device->device_type()),
270                           unit->device->name(), subgraph.get()));
271     unit->graph = subgraph.get();
272     unit->build_cost_model = graph_options.build_cost_model();
273     if (unit->build_cost_model > 0) {
274       skip_cost_models_ = false;
275     }
276     TF_RETURN_IF_ERROR(
277         NewLocalExecutor(params, std::move(subgraph), &unit->root));
278   }
279   return Status::OK();
280 }
281 
Register(const string & session,const GraphDef & gdef,const GraphOptions & graph_options,const DebugOptions & debug_options,int64 collective_graph_key,DistributedFunctionLibraryRuntime * cluster_flr,string * handle)282 Status GraphMgr::Register(const string& session, const GraphDef& gdef,
283                           const GraphOptions& graph_options,
284                           const DebugOptions& debug_options,
285                           int64 collective_graph_key,
286                           DistributedFunctionLibraryRuntime* cluster_flr,
287                           string* handle) {
288   Item* item = new Item;
289   Status s = InitItem(session, gdef, graph_options, debug_options,
290                       collective_graph_key, cluster_flr, item);
291   if (!s.ok()) {
292     item->Unref();
293     return s;
294   }
295 
296   // Inserts one item into table_.
297   {
298     mutex_lock l(mu_);
299     *handle = strings::Printf("%016llx", ++next_id_);
300     item->handle = *handle;
301     CHECK(table_.insert({*handle, item}).second);
302   }
303   return Status::OK();
304 }
305 
Deregister(const string & handle)306 Status GraphMgr::Deregister(const string& handle) {
307   Item* item = nullptr;
308   // Removes one item from table_.
309   {
310     mutex_lock l(mu_);
311     auto iter = table_.find(handle);
312     if (iter == table_.end()) {
313       return errors::Aborted("Graph handle is not found: ", handle,
314                              ". Possibly, this worker just restarted.");
315     }
316     item = iter->second;
317     table_.erase(iter);
318   }
319   item->Unref();
320   return Status::OK();
321 }
322 
DeregisterAll()323 Status GraphMgr::DeregisterAll() {
324   std::vector<Item*> items;
325   // Removes all items from table_.
326   {
327     mutex_lock l(mu_);
328     for (const auto& entry : table_) {
329       items.push_back(entry.second);
330     }
331     table_.clear();
332   }
333   for (auto item : items) {
334     item->Unref();
335   }
336   return Status::OK();
337 }
338 
SendInputs(const int64 step_id,const NamedTensors & in)339 Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
340   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
341   std::vector<string> keys;
342   std::vector<Tensor> tensors_to_send;
343   keys.reserve(in.size());
344   tensors_to_send.reserve(in.size());
345   for (const auto& p : in) {
346     keys.push_back(p.first);
347     tensors_to_send.push_back(p.second);
348   }
349   Status s =
350       SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
351   rendezvous->Unref();
352   return s;
353 }
354 
RecvOutputs(const int64 step_id,NamedTensors * out)355 Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
356   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
357   Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
358   rendezvous->Unref();
359   if (!s.ok()) {
360     // Failing to fetch the outputs should not be possible, so rewrite the error
361     // status to an INTERNAL error.
362     s = errors::Internal("Failed to fetch outputs for step ", step_id,
363                          ". (Original error message: ", s.ToString(), ")");
364   }
365   return s;
366 }
367 
RecvOutputsAsync(const int64 step_id,NamedTensors * out,StatusCallback done)368 void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
369                                 StatusCallback done) {
370   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
371   std::vector<string> keys;
372   std::vector<Tensor>* received_keys = new std::vector<Tensor>;
373   keys.reserve(out->size());
374   received_keys->reserve(out->size());
375   for (const auto& p : *out) {
376     keys.push_back(p.first);
377     received_keys->push_back(p.second);
378   }
379   RecvOutputsFromRendezvousAsync(
380       rendezvous, nullptr, {}, keys, received_keys,
381       [done, rendezvous, received_keys, out, keys](const Status s) {
382         rendezvous->Unref();
383         for (int i = 0; i < keys.size(); ++i) {
384           (*out)[keys[i]] = (*received_keys)[i];
385         }
386         delete received_keys;
387         done(s);
388       });
389 }
390 
ExecuteAsync(const string & handle,const int64 step_id,WorkerSession * session,const ExecutorOpts & opts,StepStatsCollector * collector,MutableRunGraphResponseWrapper * response,CancellationManager * cancellation_manager,const NamedTensors & in,StatusCallback done)391 void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
392                             WorkerSession* session, const ExecutorOpts& opts,
393                             StepStatsCollector* collector,
394                             MutableRunGraphResponseWrapper* response,
395                             CancellationManager* cancellation_manager,
396                             const NamedTensors& in, StatusCallback done) {
397   const uint64 start_time_usecs = Env::Default()->NowMicros();
398   // Lookup an item. Holds one ref while executing.
399   Item* item = nullptr;
400   {
401     mutex_lock l(mu_);
402     auto iter = table_.find(handle);
403     if (iter != table_.end()) {
404       item = iter->second;
405       item->Ref();
406     }
407   }
408 
409   if (item == nullptr) {
410     done(errors::Aborted("Graph handle is not found: ", handle));
411     return;
412   }
413 
414   CostGraphDef* cost_graph = nullptr;
415   if (response != nullptr) {
416     cost_graph = response->mutable_cost_graph();
417     if (opts.record_partition_graphs()) {
418       for (const ExecutionUnit& unit : item->units) {
419         GraphDef graph_def;
420         unit.graph->ToGraphDef(&graph_def);
421         response->AddPartitionGraph(graph_def);
422       }
423     }
424   }
425 
426   RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
427   Status s = rendezvous->Initialize(session);
428   CollectiveExecutor::Handle* ce_handle =
429       item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
430           ? new CollectiveExecutor::Handle(
431                 worker_env_->collective_executor_mgr->FindOrCreate(step_id),
432                 true)
433           : nullptr;
434   // Sends values specified by the caller.
435   if (s.ok()) {
436     std::vector<string> keys;
437     std::vector<Tensor> tensors_to_send;
438     keys.reserve(in.size());
439     tensors_to_send.reserve(in.size());
440     for (auto& p : in) {
441       keys.push_back(p.first);
442       tensors_to_send.push_back(p.second);
443     }
444     s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
445   }
446 
447   if (!s.ok()) {
448     done(s);
449     delete ce_handle;
450     item->Unref();
451     rendezvous->Unref();
452     return;
453   }
454 
455   StartParallelExecutors(
456       handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
457       cancellation_manager,
458       [item, rendezvous, ce_handle, done, start_time_usecs](const Status& s) {
459         done(s);
460         metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
461                                      start_time_usecs);
462         rendezvous->Unref();
463         item->Unref();
464         delete ce_handle;
465       });
466 }
467 
StartParallelExecutors(const string & handle,int64 step_id,Item * item,Rendezvous * rendezvous,CollectiveExecutor::Handle * ce_handle,StepStatsCollector * collector,CostGraphDef * cost_graph,CancellationManager * cancellation_manager,StatusCallback done)468 void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
469                                       Item* item, Rendezvous* rendezvous,
470                                       CollectiveExecutor::Handle* ce_handle,
471                                       StepStatsCollector* collector,
472                                       CostGraphDef* cost_graph,
473                                       CancellationManager* cancellation_manager,
474                                       StatusCallback done) {
475   const int num_units = item->units.size();
476   CHECK_GE(num_units, 1);
477   ScopedStepContainer* step_container = new ScopedStepContainer(
478       step_id,
479       [this](const string& name) { device_mgr_->ClearContainers({name}); });
480   // NOTE: Transfer one ref of rendezvous and item.
481   ExecutorBarrier* barrier =
482       new ExecutorBarrier(num_units, rendezvous,
483                           [this, item, collector, cost_graph, step_container,
484                            done](const Status& s) {
485                             BuildCostModel(item, collector, cost_graph);
486                             done(s);
487                             delete step_container;
488                           });
489   Executor::Args args;
490   args.step_id = step_id;
491   args.rendezvous = rendezvous;
492   args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
493   args.cancellation_manager = cancellation_manager;
494   args.stats_collector = collector;
495   args.step_container = step_container;
496   args.sync_on_finish = sync_on_finish_;
497   if (LogMemory::IsEnabled()) {
498     LogMemory::RecordStep(args.step_id, handle);
499   }
500   thread::ThreadPool* pool = worker_env_->compute_pool;
501   using std::placeholders::_1;
502   // Line below is equivalent to this code, but does one less indirect call:
503   //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
504   auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
505   for (const auto& unit : item->units) {
506     // TODO(zhengxq): if the device picks its own threadpool, we need to assign
507     //     less threads to the main compute pool by default.
508     thread::ThreadPool* device_thread_pool =
509         unit.device->tensorflow_device_thread_pool();
510     if (!device_thread_pool) {
511       args.runner = default_runner;
512     } else {
513       args.runner =
514           std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
515     }
516     unit.root->RunAsync(args, barrier->Get());
517   }
518 }
519 
BuildCostModel(Item * item,StepStatsCollector * collector,CostGraphDef * cost_graph)520 void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
521                               CostGraphDef* cost_graph) {
522   if (collector && !skip_cost_models_) {
523     // Build the cost model
524     std::unordered_map<string, const Graph*> device_to_graph;
525     for (const auto& unit : item->units) {
526       if (unit.build_cost_model > 0) {
527         device_to_graph[unit.device->name()] = unit.graph;
528       }
529     }
530     collector->BuildCostModel(&cost_model_manager_, device_to_graph);
531 
532     if (cost_graph != nullptr) {
533       for (const auto& unit : item->units) {
534         cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
535             .IgnoreError();
536       }
537     }
538   }
539 }
540 
541 }  // end namespace tensorflow
542