/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <chrono>  // NOLINT(build/c++11)
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}

GraphMgr::~GraphMgr() {
  for (const auto& p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}

// NOTE: node->device_name() is not set by GraphConstructor. We
// expect that each NodeDef in the GraphDef given to workers fully
// specifies its device name.
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}
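// For illustration: a fully specified assigned device name looks like
// "/job:worker/replica:0/task:0/device:CPU:0", and the partitioning below
// keys each subgraph by exactly this string.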

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     FormatNodeDefForError(ndef));
    }
  }
  return Status::OK();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// On success, "executors" is filled with one executor per device, and
// the caller takes ownership of the returned executors. (See the usage
// sketch after this function.)
Status GraphMgr::InitItem(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64 collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
  item->session = handle;
  item->collective_graph_key = collective_graph_key;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
  // does that below.
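  // The function library runtime below is handed a rendezvous factory that
  // returns the per-step rendezvous owned by this worker's rendezvous
  // manager, so function send/recv traffic shares the step's rendezvous.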
  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, /*config=*/&config_proto,
      gdef.versions().producer(), item->lib_def.get(),
      graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [this, session](const int64 step_id, const DeviceMgr*,
                          Rendezvous** r) -> Status {
            auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
            TF_RETURN_IF_ERROR(remote_r->Initialize(session));
            *r = remote_r;
            return Status::OK();
          },
          [this](const int64 step_id) {
            this->worker_env_->rendezvous_mgr->Cleanup(step_id);
            return Status::OK();
          }}));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  opts.validate_nodes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
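  // e.g., new_name("edge_1/x") yields "edge_1/x_G12" when next_id_ is 12,
  // so names added during partitioning never collide across registrations.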
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = &graph.flib_def();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(handle);
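    // This hold is released in Item::~Item via RemoveHold, so the cached
    // kernels live exactly as long as the registered graph.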

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }

    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel =
        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached. On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                     create_fn);
        };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
        delete kernel;
      }
    };

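    // Apply session-level rewrites controlled by optimizer_opts (e.g.,
    // common subexpression elimination and constant folding) before the
    // executor is built.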
    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = std::move(subgraph);
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
  }
  return Status::OK();
}
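
// Usage sketch (illustrative only; the master normally drives this flow via
// RunGraphRequest, and the identifiers below are hypothetical):
//
//   GraphMgr* mgr = ...;
//   string graph_handle;
//   TF_CHECK_OK(mgr->Register("session_0", gdef, worker_session,
//                             graph_options, debug_options, config_proto,
//                             BuildGraphOptions::kNoCollectiveGraphKey,
//                             cluster_flr, &graph_handle));
//   mgr->ExecuteAsync(graph_handle, step_id, worker_session, exec_opts,
//                     /*collector=*/nullptr, /*response=*/nullptr,
//                     cancellation_manager, inputs,
//                     [](const Status& s) { TF_CHECK_OK(s); });
//   TF_CHECK_OK(mgr->Deregister(graph_handle));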

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64 collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
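    // Handles are zero-padded 16-digit hex strings, e.g. "0000000000000001"
    // for the first registered graph.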
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return Status::OK();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return Status::OK();
}

Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  size_t input_size = 0;
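  // The caller supplies fully specified rendezvous keys (as produced by
  // Rendezvous::CreateKey, of the form
  // "src_device;src_incarnation;dst_device;tensor_name;frame:iter").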
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
    input_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphInputTensors(input_size);
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}

Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  if (!s.ok()) {
    // Failing to fetch the outputs should not be possible, so rewrite the
    // error status to an INTERNAL error.
    s = errors::Internal("Failed to fetch outputs for step ", step_id,
                         ". (Original error message: ", s.ToString(), ")");
  }
  size_t output_size = 0;
  for (auto& p : *out) {
    output_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphOutputTensors(output_size);
  return s;
}

void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor>* received_keys = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_keys->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_keys->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_keys,
      [done, rendezvous, received_keys, out, keys](const Status s) {
        rendezvous->Unref();
        size_t output_size = 0;
        for (int i = 0, end = keys.size(); i < end; ++i) {
          (*out)[keys[i]] = (*received_keys)[i];
          output_size += (*out)[keys[i]].AllocatedBytes();
        }
        metrics::RecordGraphOutputTensors(output_size);
        delete received_keys;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
      [step_id] {
        return profiler::TraceMeEncode(
            "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);
  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
  // Sends values specified by the caller.
  size_t input_size = 0;
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
      input_size += p.second.AllocatedBytes();
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in GraphMgr::ExecuteAsync.
            [step_id] {
              return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done(s);
        metrics::RecordGraphInputTensors(input_size);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}

void GraphMgr::StartParallelExecutors(
    const string& handle, int64 step_id, Item* item, Rendezvous* rendezvous,
    CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
    CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
    WorkerSession* session, StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
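  // When the step container is deleted (after all executors finish), its
  // cleanup lambda clears the step's per-device resource containers.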
  // NOTE: Transfer one ref of rendezvous and item.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  args.step_id = step_id;
  args.rendezvous = rendezvous;
  args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // Line below is equivalent to this code, but does one less indirect call:
  //   args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to
    // assign fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph.get();
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow