1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/context.h"
17
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
20 #include "tensorflow/core/common_runtime/device_resolver_local.h"
21 #include "tensorflow/core/common_runtime/device_set.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #ifndef __ANDROID__
25 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
26 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
27 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
28 #endif
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/lib/core/blocking_counter.h"
31 #include "tensorflow/core/util/env_var.h"
32
33 namespace tensorflow {
34 namespace {
35
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)36 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
37 bool val;
38 if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
39 return val;
40 }
41 return default_val;
42 }
43
44 } // namespace
45
// Convenience constructor that takes ownership of `device_mgr` and forwards
// to the primary constructor with device_mgr_owned == true.
EagerContext::EagerContext(const SessionOptions& opts,
                           ContextDevicePlacementPolicy default_policy,
                           bool async,
                           std::unique_ptr<const DeviceMgr> device_mgr,
                           Rendezvous* rendezvous)
    : EagerContext(opts, default_policy, async, device_mgr.release(),
                   /*device_mgr_owned*/ true, rendezvous) {}
53
// Primary constructor. `device_mgr` is either owned by this context
// (device_mgr_owned == true) or merely borrowed. The context is created with
// local-only collective machinery; remote/distributed state is attached later
// (see InitializeRemote / StoreCollectiveOpsServer below).
EagerContext::EagerContext(const SessionOptions& opts,
                           ContextDevicePlacementPolicy default_policy,
                           bool async, const DeviceMgr* device_mgr,
                           bool device_mgr_owned, Rendezvous* rendezvous)
    : policy_(default_policy),
      devices_(device_mgr->ListDevices()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      pflr_(new ProcessFunctionLibraryRuntime(
          device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
          opts.config.graph_options().optimizer_options(), thread_pool_.get())),
      log_device_placement_(opts.config.log_device_placement()),
      num_active_steps_(0),
      async_default_(async),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      use_send_tensor_rpc_(false),
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
  // Exactly one of local_device_manager_ (owning) or
  // local_unowned_device_manager_ (borrowed) ends up set.
  if (device_mgr_owned) {
    local_device_manager_.reset(device_mgr);
    local_unowned_device_manager_ = nullptr;
  } else {
    local_unowned_device_manager_ = device_mgr;
  }
  InitDeviceMapAndAsync();
  // Default runner executes closures on this context's thread pool.
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  // Local-only collective executor; replaced when a collective-ops server is
  // stored via StoreCollectiveOpsServer().
  std::unique_ptr<DeviceResolverInterface> drl(
      new DeviceResolverLocal(local_device_mgr()));
  std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
      opts.config, local_device_mgr(), drl.get(),
      "/job:localhost/replica:0/task:0"));
  collective_executor_mgr_.reset(new CollectiveExecutorMgr(
      opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
}
92
InitDeviceMapAndAsync()93 void EagerContext::InitDeviceMapAndAsync() {
94 if (async_default_) {
95 executor_.EnableAsync();
96 }
97
98 for (auto* device : devices_) {
99 devices_map_[device->name()] = device;
100 }
101
102 if (remote_device_manager_ != nullptr) {
103 for (auto* device : remote_device_manager_->ListDevices()) {
104 if (devices_map_.find(device->name()) == devices_map_.end()) {
105 devices_map_[device->name()] = device;
106 devices_.push_back(device);
107 }
108 }
109 }
110
111 DeviceSet ds;
112 for (Device* d : devices_) {
113 ds.AddDevice(d);
114 }
115 prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
116 }
117
Async() const118 bool EagerContext::Async() const {
119 mutex_lock l(async_map_mu_);
120 return gtl::FindWithDefault(thread_local_async_, std::this_thread::get_id(),
121 async_default_);
122 }
123
SetAsyncForThread(bool async)124 Status EagerContext::SetAsyncForThread(bool async) {
125 {
126 tensorflow::mutex_lock l(async_map_mu_);
127 thread_local_async_[std::this_thread::get_id()] = async;
128 }
129 if (async) {
130 executor_.EnableAsync();
131 } else {
132 // TODO(agarwal): Currently we add a wait here to handle cases where a
133 // sync op has a control dependency on an async op, and the latter has not
134 // executed yet. This wait can be removed by storing all the control
135 // inputs and waiting for them when executing ops.
136 return executor_.WaitForAllPendingNodes();
137 }
138 return Status::OK();
139 }
140
ClearCaches()141 Status EagerContext::ClearCaches() {
142 // The executor stores pointers to kernels, so we need to make sure that no
143 // async eager ops are still executing. We lock the cache during this time as
144 // well.
145 mutex_lock ml(cache_mu_);
146 TF_RETURN_IF_ERROR(executor_.WaitForAllPendingNodes());
147 gtl::STLDeleteValues(&kernel_cache_);
148
149 return Status::OK();
150 }
151
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)152 void EagerContext::SetThreadLocalDevicePlacementPolicy(
153 ContextDevicePlacementPolicy policy) {
154 mutex_lock ml(policy_map_mu_);
155 thread_local_policies_[std::this_thread::get_id()] = policy;
156 }
157
GetDevicePlacementPolicy()158 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
159 mutex_lock ml(policy_map_mu_);
160 auto policy_map_it = thread_local_policies_.find(std::this_thread::get_id());
161 if (policy_map_it != thread_local_policies_.end()) {
162 return policy_map_it->second;
163 }
164 return policy_;
165 }
166
#ifndef __ANDROID__
// Sends a CloseContext RPC to every remote worker in remote_contexts_ and
// blocks until all of them have responded. RPC failures are logged but not
// propagated to the caller.
void EagerContext::CloseRemoteContexts() {
  // Close all remote contexts.
  std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
  std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
  // The request/response vectors and the map entries captured by reference
  // below must outlive the async callbacks; counter.Wait() guarantees that.
  BlockingCounter counter(static_cast<int>(remote_contexts_.size()));

  int i = 0;
  for (const auto& worker_and_context_id : remote_contexts_) {
    auto* client =
        remote_eager_workers_->GetClient(worker_and_context_id.first);

    requests[i].set_context_id(worker_and_context_id.second);
    client->CloseContextAsync(
        &requests[i], &responses[i],
        [&worker_and_context_id, &counter](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Unable to close remote context with ID "
                       << worker_and_context_id.second
                       << " for worker: " << worker_and_context_id.first
                       << " due to " << s.error_message();
          }
          counter.DecrementCount();
        });
    i++;
  }

  counter.Wait();
}
#endif
197
// Tears down the context: stops remote machinery (keep-alive thread, remote
// contexts), drains pending async work, clears the kernel cache, releases
// the rendezvous, then joins owned child threads.
EagerContext::~EagerContext() {
#ifndef __ANDROID__
  if (server_) {
    // TODO(nareshmodi): Fix this.
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();  // Deliberate leak until servers support shutdown.
  }

  {
    // Wake the keep-alive thread so it observes shutting_down_ and exits.
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();  // Blocks until the keep-alive thread joins.

  CloseRemoteContexts();
#endif

  executor_.WaitForAllPendingNodes().IgnoreError();
  ClearCaches().IgnoreError();
  rendezvous_->Unref();

  // Destroy (join) threads registered via AddChildThread.
  for (auto& thread : child_threads_) {
    thread.reset();
  }
}
225
AddChildThread(std::unique_ptr<Thread> thread)226 void EagerContext::AddChildThread(std::unique_ptr<Thread> thread) {
227 child_threads_.push_back(std::move(thread));
228 }
229
FindFunctionByName(const string & name)230 bool EagerContext::FindFunctionByName(const string& name) {
231 mutex_lock l(functions_mu_);
232 return func_lib_def_.Find(name) != nullptr;
233 }
234
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)235 Status EagerContext::FindFunctionOpData(
236 const string& name, const tensorflow::OpRegistrationData** op_data) {
237 mutex_lock l(functions_mu_);
238 return func_lib_def_.LookUp(name, op_data);
239 }
240
FindFunctionDef(const string & name)241 const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
242 mutex_lock l(functions_mu_);
243 return func_lib_def_.Find(name);
244 }
245
FindDeviceByName(const string & name,Device ** result)246 Status EagerContext::FindDeviceByName(const string& name, Device** result) {
247 auto it = devices_map_.find(name);
248 if (it == devices_map_.end()) {
249 return errors::InvalidArgument(name, " unknown device.");
250 }
251 *result = it->second;
252 return Status::OK();
253 }
254
// Clears the accumulated RunMetadata, first notifying any registered
// listener so it can snapshot the data before it is discarded.
// NOTE(review): metadata_listener_ and run_metadata_ are guarded by
// metadata_mu_ elsewhere in this file, but no lock is taken here —
// presumably every caller already holds metadata_mu_ (adding a lock here
// would then deadlock). Verify this invariant at all call sites.
void EagerContext::ClearRunMetadata() {
  if (metadata_listener_ != nullptr) {
    metadata_listener_->BeforeClearRunMetadata();
  }
  run_metadata_.Clear();
}
261
RegisterRunMetadataListener(RunMetadataListener * listener)262 Status EagerContext::RegisterRunMetadataListener(
263 RunMetadataListener* listener) {
264 mutex_lock l(metadata_mu_);
265 if (metadata_listener_ != nullptr) {
266 return Status(error::Code::INVALID_ARGUMENT,
267 "Cannot run two eager profiler at the same time");
268 }
269 metadata_listener_ = listener;
270 return Status::OK();
271 }
272
ClearRunMetadataListener()273 void EagerContext::ClearRunMetadataListener() {
274 mutex_lock l(metadata_mu_);
275 metadata_listener_ = nullptr;
276 }
277
StartStep()278 void EagerContext::StartStep() {
279 mutex_lock ml(metadata_mu_);
280 num_active_steps_++;
281 if (step_container_ == nullptr) {
282 step_container_.reset(
283 new ScopedStepContainer(0, [this](const string& name) {
284 for (Device* device : devices_) {
285 device->resource_manager()->Cleanup(name).IgnoreError();
286 }
287 }));
288 }
289 }
290
EndStep()291 void EagerContext::EndStep() {
292 mutex_lock ml(metadata_mu_);
293 num_active_steps_--;
294 if (num_active_steps_ == 0) {
295 step_container_.reset();
296 }
297 }
298
StepContainer()299 ScopedStepContainer* EagerContext::StepContainer() {
300 if (num_active_steps_.load() == 0) {
301 return nullptr;
302 }
303 mutex_lock ml(metadata_mu_);
304 return step_container_.get();
305 }
306
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)307 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
308 if (remote_device_manager_ == nullptr) return Status::OK();
309 #ifndef __ANDROID__
310 BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
311
312 std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
313 std::vector<eager::RegisterFunctionResponse> responses(
314 remote_contexts_.size());
315 std::vector<Status> statuses(remote_contexts_.size());
316
317 int i = 0;
318 for (const auto& target_and_context_id : remote_contexts_) {
319 requests[i].set_context_id(target_and_context_id.second);
320 *requests[i].mutable_function_def() = fdef;
321
322 auto* eager_client =
323 remote_eager_workers_->GetClient(target_and_context_id.first);
324
325 eager_client->RegisterFunctionAsync(
326 &requests[i], &responses[i],
327 [i, &statuses, &blocking_counter](const Status& status) {
328 statuses[i] = status;
329 blocking_counter.DecrementCount();
330 });
331
332 i++;
333 }
334 blocking_counter.Wait();
335
336 for (int i = 0; i < remote_contexts_.size(); i++) {
337 TF_RETURN_IF_ERROR(statuses[i]);
338 }
339 #endif
340 return Status::OK();
341 }
342
// Registers `fdef` in the local function library and, when remote workers
// are attached, mirrors the registration to each of them.
// NOTE(review): functions_mu_ remains held across
// MaybeRegisterFunctionRemotely, which performs blocking RPCs — confirm this
// serialization of function registration is intentional.
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
  mutex_lock l(functions_mu_);
  TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));

  return MaybeRegisterFunctionRemotely(fdef);
}
349
GetCachedKernel(Fprint128 cache_key)350 KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
351 tf_shared_lock l(cache_mu_);
352 return gtl::FindPtrOrNull(kernel_cache_, cache_key);
353 }
354
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)355 void EagerContext::AddKernelToCache(Fprint128 cache_key,
356 KernelAndDevice* kernel) {
357 mutex_lock ml(cache_mu_);
358 gtl::InsertOrUpdate(&kernel_cache_, cache_key, kernel);
359 }
360
ShouldStoreGraphs()361 bool EagerContext::ShouldStoreGraphs() {
362 mutex_lock ml(metadata_mu_);
363 return should_store_graphs_.load() || metadata_listener_ != nullptr;
364 }
365
ShouldStoreStepStats()366 bool EagerContext::ShouldStoreStepStats() {
367 mutex_lock ml(metadata_mu_);
368 return should_store_step_stats_.load() || metadata_listener_ != nullptr;
369 }
370
SetShouldStoreGraphs(bool value)371 void EagerContext::SetShouldStoreGraphs(bool value) {
372 mutex_lock ml(metadata_mu_);
373 should_store_graphs_.store(value);
374 if (!value || metadata_listener_ != nullptr) {
375 run_metadata_.Clear();
376 }
377 }
378
SetShouldStoreStepStats(bool value)379 void EagerContext::SetShouldStoreStepStats(bool value) {
380 mutex_lock ml(metadata_mu_);
381 should_store_step_stats_.store(value);
382 if (!value || metadata_listener_ != nullptr) {
383 run_metadata_.Clear();
384 }
385 }
386
387 namespace {
GetTaskName(Device * d,string * task_name)388 Status GetTaskName(Device* d, string* task_name) {
389 string ignored;
390 if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
391 return errors::InvalidArgument("Unable to parse device name: ", d->name());
392 }
393
394 return Status::OK();
395 }
396 } // namespace
397
398 #ifndef __ANDROID__
GetClientAndContextID(Device * device,eager::EagerClient ** client,uint64 * context_id)399 Status EagerContext::GetClientAndContextID(Device* device,
400 eager::EagerClient** client,
401 uint64* context_id) {
402 auto it = device_to_client_cache_.find(device);
403 if (it != device_to_client_cache_.end()) {
404 *client = it->second.first;
405 *context_id = it->second.second;
406 }
407 string device_task_name;
408 TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
409
410 *client = remote_eager_workers_->GetClient(device_task_name);
411
412 if (*client == nullptr) {
413 return errors::InvalidArgument(
414 "Unable to find eager client corresponding to device ", device->name());
415 }
416
417 auto context_iterator = remote_contexts_.find(device_task_name);
418 if (context_iterator == remote_contexts_.end()) {
419 return errors::Internal("Unable to find a context for handle on task: ",
420 device_task_name, ". This should not be possible");
421 }
422 *context_id = context_iterator->second;
423
424 device_to_client_cache_.insert({device, {*client, *context_id}});
425
426 return Status::OK();
427 }
428
// Adopts the device manager and collective executor belonging to a
// collective-ops server: rebuilds the device map, kernel cache and function
// library runtime around them, then takes ownership of `server`.
// `device_mgr` and `rpc_collective_executor_mgr` are borrowed, not owned.
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  // Drop the locally owned executor mgr in favor of the unowned RPC one.
  collective_executor_mgr_.reset(nullptr);
  unowned_collective_executor_mgr_ = rpc_collective_executor_mgr;

  // Likewise swap any owned device manager for the server's (borrowed) one.
  local_device_manager_.reset(nullptr);
  local_unowned_device_manager_ = device_mgr;

  devices_ = local_unowned_device_manager_->ListDevices();
  devices_map_.clear();

  InitDeviceMapAndAsync();
  // Cached kernels hold pointers into the old device set and must go.
  TF_RETURN_IF_ERROR(ClearCaches());

  pflr_.reset(new ProcessFunctionLibraryRuntime(
      local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
      {}, thread_pool_.get()));

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  return Status::OK();
}
458
// Attaches this context to a set of remote workers: closes any previously
// attached remote contexts, swaps in the new server, eager-client cache,
// rendezvous and remote device manager, rebuilds local runtime state, and
// (once) starts a background thread that periodically sends KeepAlive RPCs
// so remote contexts are not garbage-collected.
Status EagerContext::InitializeRemote(
    std::unique_ptr<ServerInterface> server,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DeviceMgr> remote_device_manager,
    const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
    DeviceMgr* local_device_mgr, int keep_alive_secs) {
  mutex_lock l(remote_state_mu_);

  // Re-initialization: tear down the previous remote contexts first.
  if (!remote_contexts_.empty()) {
    CloseRemoteContexts();
  }
  remote_contexts_ = remote_contexts;

  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);

  // The local device manager is now borrowed from the caller.
  local_unowned_device_manager_ = local_device_mgr;
  local_device_manager_ = nullptr;
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
      {}, thread_pool_.get()));

  devices_ = local_unowned_device_manager_->ListDevices();
  devices_map_.clear();

  // Replace the rendezvous, releasing our reference on the old one.
  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }

  server_ = std::move(server);
  remote_eager_workers_ = std::move(remote_eager_workers);

  active_remote_contexts_.clear();
  for (const auto& remote_context : remote_contexts_) {
    active_remote_contexts_.insert(remote_context.second);
  }

  // The client cache maps into the old worker set; rebuild lazily.
  device_to_client_cache_.clear();
  remote_device_manager_ = std::move(remote_device_manager);

  InitDeviceMapAndAsync();

  // Cached kernels reference devices from the previous configuration.
  TF_RETURN_IF_ERROR(ClearCaches());

  keep_alive_secs_ = keep_alive_secs;

  // Ping at roughly half the keep-alive interval, but at least every second.
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);

  // Only schedule a single closure.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              {
                // Sleep until the next ping is due or shutdown is signaled.
                mutex_lock l(keep_alive_thread_shutdown_mu_);
                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                if (shutting_down_) {
                  return;
                }
              }
              {
                mutex_lock l(remote_state_mu_);
                if (keep_alive_secs_ > 0) {
                  {
                    // Fire-and-forget KeepAlive RPCs; request/response are
                    // heap-allocated and freed by the completion callback.
                    for (const auto& worker_and_context_id : remote_contexts_) {
                      auto* client = remote_eager_workers_->GetClient(
                          worker_and_context_id.first);

                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(worker_and_context_id.second);
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
558 #endif
559
560 } // namespace tensorflow
561