• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
17 
18 #include <chrono>  // NOLINT(build/c++11)
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/build_graph_options.h"
22 #include "tensorflow/core/common_runtime/constant_folding.h"
23 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_mgr.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_optimizer.h"
29 #include "tensorflow/core/common_runtime/memory_types.h"
30 #include "tensorflow/core/common_runtime/metrics.h"
31 #include "tensorflow/core/common_runtime/optimization_registry.h"
32 #include "tensorflow/core/common_runtime/process_util.h"
33 #include "tensorflow/core/common_runtime/rendezvous_util.h"
34 #include "tensorflow/core/common_runtime/step_stats_collector.h"
35 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
36 #include "tensorflow/core/framework/cancellation.h"
37 #include "tensorflow/core/framework/collective.h"
38 #include "tensorflow/core/framework/log_memory.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/versions.pb.h"
42 #include "tensorflow/core/graph/graph.h"
43 #include "tensorflow/core/graph/graph_partition.h"
44 #include "tensorflow/core/graph/validate.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/strings/stringprintf.h"
47 #include "tensorflow/core/platform/env.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/tracing.h"
51 #include "tensorflow/core/platform/types.h"
52 #include "tensorflow/core/profiler/lib/connected_traceme.h"
53 #include "tensorflow/core/profiler/lib/traceme_encode.h"
54 #include "tensorflow/core/protobuf/worker.pb.h"
55 #include "tensorflow/core/util/env_var.h"
56 
57 namespace tensorflow {
58 
GraphMgr(const WorkerEnv * worker_env,const DeviceMgr * device_mgr)59 GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
60     : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
61   // The default value of sync_on_finish will be flipped soon and this
62   // environment variable will be removed as well.
63   Status status =
64       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
65   if (!status.ok()) {
66     LOG(ERROR) << status.error_message();
67   }
68 }
69 
~GraphMgr()70 GraphMgr::~GraphMgr() {
71   for (const auto& p : table_) p.second->Unref();
72 }
73 
~Item()74 GraphMgr::Item::~Item() {
75   for (const auto& unit : this->units) {
76     CHECK_NOTNULL(unit.device);
77     if (!graph_mgr->skip_cost_models_) {
78       graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
79     }
80     delete unit.root;
81     unit.device->op_segment()->RemoveHold(this->session);
82   }
83 }
84 
85 // NOTE: node->device_name() is not set by GraphConstructor.  We
86 // expects that NodeDef in GraphDef given to workers fully specifies
87 // device names.
SplitByDevice(const Node * node)88 static string SplitByDevice(const Node* node) {
89   return node->assigned_device_name();
90 }
91 
92 // Validates "gdef" device specifications.
ValidateGraphDefForDevices(const GraphDef & gdef)93 static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
94   DeviceNameUtils::ParsedName parsed;
95   for (const auto& ndef : gdef.node()) {
96     if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
97       return errors::InvalidArgument("Missing device name in: ",
98                                      FormatNodeDefForError(ndef));
99     }
100   }
101   return Status::OK();
102 }
103 
DecorateAndPublishGraphForDebug(const DebugOptions & debug_options,Graph * graph,Device * device)104 Status GraphMgr::DecorateAndPublishGraphForDebug(
105     const DebugOptions& debug_options, Graph* graph, Device* device) {
106   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
107   TF_RETURN_IF_ERROR(
108       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
109   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
110   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
111   return Status::OK();
112 }
113 
114 // Creates executors given a graph definition "gdef" of a "session".
115 // If a node in "gdef" is shared by other graphs in "session", the
116 // same op kernel is reused. E.g., typically a params node is shared
117 // by multiple graphs in a session.
118 //
119 // If "gdef" is assigned to multiple devices, extra nodes (e.g.,
120 // send/recv nodes) maybe added. The extra nodes' name are generated
121 // by calling "new_name(old_name)".
122 //
123 // "executors" are filled with one executor per device if success and
124 // the caller takes the ownership of returned executors.
InitItem(const string & handle,const GraphDef & gdef,WorkerSession * session,const GraphOptions & graph_options,const DebugOptions & debug_options,const ConfigProto & config_proto,int64 collective_graph_key,DistributedFunctionLibraryRuntime * cluster_flr,Item * item)125 Status GraphMgr::InitItem(
126     const string& handle, const GraphDef& gdef, WorkerSession* session,
127     const GraphOptions& graph_options, const DebugOptions& debug_options,
128     const ConfigProto& config_proto, int64 collective_graph_key,
129     DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
130   item->session = handle;
131   item->collective_graph_key = collective_graph_key;
132   item->lib_def.reset(
133       new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
134 
135   TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
136 
137   // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
138   // does that below.
139   item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
140       device_mgr_, worker_env_->env, /*config=*/&config_proto,
141       gdef.versions().producer(), item->lib_def.get(),
142       graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
143       /*session_metadata=*/nullptr,
144       Rendezvous::Factory{
145           [this, session](const int64 step_id, const DeviceMgr*,
146                           Rendezvous** r) -> Status {
147             auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
148             TF_RETURN_IF_ERROR(remote_r->Initialize(session));
149             *r = remote_r;
150             return Status::OK();
151           },
152           [this](const int64 step_id) {
153             this->worker_env_->rendezvous_mgr->Cleanup(step_id);
154             return Status::OK();
155           }}));
156 
157   // Constructs the graph out of "gdef".
158   Graph graph(OpRegistry::Global());
159   GraphConstructorOptions opts;
160   opts.allow_internal_ops = true;
161   opts.expect_device_spec = true;
162   opts.validate_nodes = true;
163   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
164 
165   // Splits "graph" into multiple subgraphs by device names.
166   std::unordered_map<string, GraphDef> partitions;
167   PartitionOptions popts;
168   popts.node_to_loc = SplitByDevice;
169   popts.new_name = [this](const string& prefix) {
170     mutex_lock l(mu_);
171     return strings::StrCat(prefix, "_G", next_id_++);
172   };
173   popts.get_incarnation = [this](const string& name) -> int64 {
174     Device* device = nullptr;
175     Status s = device_mgr_->LookupDevice(name, &device);
176     if (s.ok()) {
177       return device->attributes().incarnation();
178     } else {
179       return PartitionOptions::kIllegalIncarnation;
180     }
181   };
182   popts.flib_def = &graph.flib_def();
183   popts.control_flow_added = true;
184   popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
185   TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
186   if (popts.scheduling_for_recvs) {
187     TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
188   }
189 
190   std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
191   for (auto& partition : partitions) {
192     std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
193     GraphConstructorOptions device_opts;
194     // There are internal operations (e.g., send/recv) that we now allow.
195     device_opts.allow_internal_ops = true;
196     device_opts.expect_device_spec = true;
197     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
198         device_opts, std::move(partition.second), device_graph.get()));
199     partition_graphs.emplace(partition.first, std::move(device_graph));
200   }
201 
202   GraphOptimizationPassOptions optimization_options;
203   optimization_options.flib_def = item->lib_def.get();
204   optimization_options.partition_graphs = &partition_graphs;
205   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
206       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
207 
208   LocalExecutorParams params;
209 
210   item->units.reserve(partitions.size());
211   item->graph_mgr = this;
212   const auto& optimizer_opts = graph_options.optimizer_options();
213   GraphOptimizer optimizer(optimizer_opts);
214   for (auto& p : partition_graphs) {
215     const string& device_name = p.first;
216     std::unique_ptr<Graph>& subgraph = p.second;
217     item->units.resize(item->units.size() + 1);
218     ExecutionUnit* unit = &(item->units.back());
219 
220     // Find the device.
221     Status s = device_mgr_->LookupDevice(device_name, &unit->device);
222     if (!s.ok()) {
223       // Remove the empty unit from the item as the item destructor wants all
224       // units to have valid devices.
225       item->units.pop_back();
226       return s;
227     }
228 
229     // Give the device an opportunity to rewrite its subgraph.
230     TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));
231 
232     // Top-level nodes in the graph uses the op segment to cache
233     // kernels. Therefore, as long as the executor is alive, we need
234     // to ensure the kernels cached for the session are alive.
235     auto opseg = unit->device->op_segment();
236     opseg->AddHold(handle);
237 
238     // Function library runtime.
239     FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
240     if (lib == nullptr) {
241       return errors::InvalidArgument("Cannot find FLR for device: ",
242                                      unit->device->name());
243     }
244 
245     // Construct the root executor for the subgraph.
246     params.device = unit->device;
247     params.function_library = lib;
248     params.create_kernel =
249         [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
250                              OpKernel** kernel) {
251           // NOTE(mrry): We must not share function kernels (implemented
252           // using `CallOp`) between subgraphs, because `CallOp::handle_`
253           // is tied to a particular subgraph. Even if the function itself
254           // is stateful, the `CallOp` that invokes it is not.
255           if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
256             return lib->CreateKernel(props, kernel);
257           }
258           auto create_fn = [lib, &props](OpKernel** kernel) {
259             return lib->CreateKernel(props, kernel);
260           };
261           // Kernels created for subgraph nodes need to be cached.  On
262           // cache miss, create_fn() is invoked to create a kernel based
263           // on the function library here + global op registry.
264           return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
265                                      create_fn);
266         };
267     params.delete_kernel = [lib](OpKernel* kernel) {
268       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
269         delete kernel;
270       }
271     };
272 
273     optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
274                        /*shape_map=*/nullptr);
275 
276     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
277     if (!debug_options.debug_tensor_watch_opts().empty()) {
278       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
279           debug_options, subgraph.get(), params.device));
280     }
281 
282     TF_RETURN_IF_ERROR(
283         EnsureMemoryTypes(DeviceType(unit->device->device_type()),
284                           unit->device->name(), subgraph.get()));
285     unit->graph = std::move(subgraph);
286     unit->build_cost_model = graph_options.build_cost_model();
287     if (unit->build_cost_model > 0) {
288       skip_cost_models_ = false;
289     }
290     TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
291   }
292   return Status::OK();
293 }
294 
Register(const string & handle,const GraphDef & gdef,WorkerSession * session,const GraphOptions & graph_options,const DebugOptions & debug_options,const ConfigProto & config_proto,int64 collective_graph_key,DistributedFunctionLibraryRuntime * cluster_flr,string * graph_handle)295 Status GraphMgr::Register(
296     const string& handle, const GraphDef& gdef, WorkerSession* session,
297     const GraphOptions& graph_options, const DebugOptions& debug_options,
298     const ConfigProto& config_proto, int64 collective_graph_key,
299     DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
300   Item* item = new Item;
301   Status s = InitItem(handle, gdef, session, graph_options, debug_options,
302                       config_proto, collective_graph_key, cluster_flr, item);
303   if (!s.ok()) {
304     item->Unref();
305     return s;
306   }
307 
308   // Inserts one item into table_.
309   {
310     mutex_lock l(mu_);
311     *graph_handle =
312         strings::Printf("%016llx", static_cast<long long>(++next_id_));
313     item->handle = *graph_handle;
314     CHECK(table_.insert({*graph_handle, item}).second);
315   }
316   return Status::OK();
317 }
318 
Deregister(const string & handle)319 Status GraphMgr::Deregister(const string& handle) {
320   Item* item = nullptr;
321   // Removes one item from table_.
322   {
323     mutex_lock l(mu_);
324     auto iter = table_.find(handle);
325     if (iter == table_.end()) {
326       return errors::Aborted("Graph handle is not found: ", handle,
327                              ". Possibly, this worker just restarted.");
328     }
329     item = iter->second;
330     table_.erase(iter);
331   }
332   item->Unref();
333   return Status::OK();
334 }
335 
DeregisterAll()336 Status GraphMgr::DeregisterAll() {
337   std::vector<Item*> items;
338   // Removes all items from table_.
339   {
340     mutex_lock l(mu_);
341     for (const auto& entry : table_) {
342       items.push_back(entry.second);
343     }
344     table_.clear();
345   }
346   for (auto item : items) {
347     item->Unref();
348   }
349   return Status::OK();
350 }
351 
SendInputs(const int64 step_id,const NamedTensors & in)352 Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
353   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
354   std::vector<string> keys;
355   std::vector<Tensor> tensors_to_send;
356   keys.reserve(in.size());
357   tensors_to_send.reserve(in.size());
358   size_t input_size = 0;
359   for (const auto& p : in) {
360     keys.push_back(p.first);
361     tensors_to_send.push_back(p.second);
362     input_size += p.second.AllocatedBytes();
363   }
364   metrics::RecordGraphInputTensors(input_size);
365   Status s =
366       SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
367   rendezvous->Unref();
368   return s;
369 }
370 
RecvOutputs(const int64 step_id,NamedTensors * out)371 Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
372   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
373   Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
374   rendezvous->Unref();
375   if (!s.ok()) {
376     // Failing to fetch the outputs should not be possible, so rewrite the error
377     // status to an INTERNAL error.
378     s = errors::Internal("Failed to fetch outputs for step ", step_id,
379                          ". (Original error message: ", s.ToString(), ")");
380   }
381   size_t output_size = 0;
382   for (auto& p : *out) {
383     output_size += p.second.AllocatedBytes();
384   }
385   metrics::RecordGraphOutputTensors(output_size);
386   return s;
387 }
388 
RecvOutputsAsync(const int64 step_id,NamedTensors * out,StatusCallback done)389 void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
390                                 StatusCallback done) {
391   Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
392   std::vector<string> keys;
393   std::vector<Tensor>* received_keys = new std::vector<Tensor>;
394   keys.reserve(out->size());
395   received_keys->reserve(out->size());
396   for (const auto& p : *out) {
397     keys.push_back(p.first);
398     received_keys->push_back(p.second);
399   }
400   RecvOutputsFromRendezvousAsync(
401       rendezvous, nullptr, {}, keys, received_keys,
402       [done, rendezvous, received_keys, out, keys](const Status s) {
403         rendezvous->Unref();
404         size_t output_size = 0;
405         for (int i = 0, end = keys.size(); i < end; ++i) {
406           (*out)[keys[i]] = (*received_keys)[i];
407           output_size += (*out)[keys[i]].AllocatedBytes();
408         }
409         metrics::RecordGraphOutputTensors(output_size);
410         delete received_keys;
411         done(s);
412       });
413 }
414 
ExecuteAsync(const string & handle,const int64 step_id,WorkerSession * session,const ExecutorOpts & opts,StepStatsCollector * collector,MutableRunGraphResponseWrapper * response,CancellationManager * cancellation_manager,const NamedTensors & in,StatusCallback done)415 void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
416                             WorkerSession* session, const ExecutorOpts& opts,
417                             StepStatsCollector* collector,
418                             MutableRunGraphResponseWrapper* response,
419                             CancellationManager* cancellation_manager,
420                             const NamedTensors& in, StatusCallback done) {
421   const uint64 start_time_usecs = Env::Default()->NowMicros();
422   profiler::TraceMeProducer activity(
423       // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
424       [step_id] {
425         return profiler::TraceMeEncode(
426             "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
427       },
428       profiler::ContextType::kTfExecutor, step_id,
429       profiler::TraceMeLevel::kInfo);
430   // Lookup an item. Holds one ref while executing.
431   Item* item = nullptr;
432   {
433     mutex_lock l(mu_);
434     auto iter = table_.find(handle);
435     if (iter != table_.end()) {
436       item = iter->second;
437       item->Ref();
438     }
439   }
440 
441   if (item == nullptr) {
442     done(errors::Aborted("Graph handle is not found: ", handle));
443     return;
444   }
445 
446   CostGraphDef* cost_graph = nullptr;
447   if (response != nullptr) {
448     cost_graph = response->mutable_cost_graph();
449     if (opts.record_partition_graphs()) {
450       for (const ExecutionUnit& unit : item->units) {
451         GraphDef graph_def;
452         unit.graph->ToGraphDef(&graph_def);
453         response->AddPartitionGraph(graph_def);
454       }
455     }
456   }
457 
458   RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
459   Status s = rendezvous->Initialize(session);
460   CollectiveExecutor::Handle* ce_handle =
461       item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
462           ? new CollectiveExecutor::Handle(
463                 worker_env_->collective_executor_mgr->FindOrCreate(step_id),
464                 true)
465           : nullptr;
466   // Sends values specified by the caller.
467   size_t input_size = 0;
468   if (s.ok()) {
469     std::vector<string> keys;
470     std::vector<Tensor> tensors_to_send;
471     keys.reserve(in.size());
472     tensors_to_send.reserve(in.size());
473     for (auto& p : in) {
474       keys.push_back(p.first);
475       tensors_to_send.push_back(p.second);
476       input_size += p.second.AllocatedBytes();
477     }
478     s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
479   }
480 
481   if (!s.ok()) {
482     done(s);
483     delete ce_handle;
484     item->Unref();
485     rendezvous->Unref();
486     return;
487   }
488 
489   StartParallelExecutors(
490       handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
491       cancellation_manager, session,
492       [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
493        step_id](const Status& s) {
494         profiler::TraceMeConsumer activity(
495             // From TraceMeProducer in GraphMgr::ExecuteAsync.
496             [step_id] {
497               return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
498             },
499             profiler::ContextType::kTfExecutor, step_id,
500             profiler::TraceMeLevel::kInfo);
501         done(s);
502         metrics::RecordGraphInputTensors(input_size);
503         metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
504                                      start_time_usecs);
505         rendezvous->Unref();
506         item->Unref();
507         delete ce_handle;
508       });
509 }
510 
StartParallelExecutors(const string & handle,int64 step_id,Item * item,Rendezvous * rendezvous,CollectiveExecutor::Handle * ce_handle,StepStatsCollector * collector,CostGraphDef * cost_graph,CancellationManager * cancellation_manager,WorkerSession * session,StatusCallback done)511 void GraphMgr::StartParallelExecutors(
512     const string& handle, int64 step_id, Item* item, Rendezvous* rendezvous,
513     CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
514     CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
515     WorkerSession* session, StatusCallback done) {
516   const int num_units = item->units.size();
517   CHECK_GE(num_units, 1);
518   ScopedStepContainer* step_container = new ScopedStepContainer(
519       step_id,
520       [this](const string& name) { device_mgr_->ClearContainers({name}); });
521   // NOTE: Transfer one ref of rendezvous and item.
522   ExecutorBarrier* barrier =
523       new ExecutorBarrier(num_units, rendezvous,
524                           [this, item, collector, cost_graph, step_container,
525                            done](const Status& s) {
526                             BuildCostModel(item, collector, cost_graph);
527                             done(s);
528                             delete step_container;
529                           });
530   Executor::Args args;
531   args.step_id = step_id;
532   args.rendezvous = rendezvous;
533   args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
534   args.cancellation_manager = cancellation_manager;
535   args.stats_collector = collector;
536   args.step_container = step_container;
537   args.sync_on_finish = sync_on_finish_;
538   if (LogMemory::IsEnabled()) {
539     LogMemory::RecordStep(args.step_id, handle);
540   }
541   thread::ThreadPool* pool = worker_env_->compute_pool;
542   using std::placeholders::_1;
543   // Line below is equivalent to this code, but does one less indirect call:
544   //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
545   auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
546   for (const auto& unit : item->units) {
547     // TODO(zhengxq): if the device picks its own threadpool, we need to assign
548     //     less threads to the main compute pool by default.
549     thread::ThreadPool* device_thread_pool =
550         unit.device->tensorflow_device_thread_pool();
551     if (!device_thread_pool) {
552       args.runner = default_runner;
553     } else {
554       args.runner =
555           std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
556     }
557     unit.root->RunAsync(args, barrier->Get());
558   }
559 }
560 
BuildCostModel(Item * item,StepStatsCollector * collector,CostGraphDef * cost_graph)561 void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
562                               CostGraphDef* cost_graph) {
563   if (collector && !skip_cost_models_) {
564     // Build the cost model
565     std::unordered_map<string, const Graph*> device_to_graph;
566     for (const auto& unit : item->units) {
567       if (unit.build_cost_model > 0) {
568         device_to_graph[unit.device->name()] = unit.graph.get();
569       }
570     }
571     collector->BuildCostModel(&cost_model_manager_, device_to_graph);
572 
573     if (cost_graph != nullptr) {
574       for (const auto& unit : item->units) {
575         cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
576             .IgnoreError();
577       }
578     }
579   }
580 }
581 
582 }  // end namespace tensorflow
583