1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/c/eager/immediate_execution_context.h"
30 #include "tensorflow/core/common_runtime/composite_device.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/eager/custom_device.h"
34 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
35 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/example/example.pb.h"
41 #include "tensorflow/core/framework/collective.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/log_memory.h"
44 #include "tensorflow/core/framework/rendezvous.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/core/threadpool.h"
49 #include "tensorflow/core/lib/gtl/flatmap.h"
50 #include "tensorflow/core/lib/gtl/flatset.h"
51 #include "tensorflow/core/lib/gtl/inlined_vector.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/fingerprint.h"
56 #include "tensorflow/core/platform/mutex.h"
57 #include "tensorflow/core/platform/platform.h"
58 #include "tensorflow/core/platform/status.h"
59 #include "tensorflow/core/platform/thread_annotations.h"
60 #include "tensorflow/core/public/session_options.h"
61 #include "tensorflow/core/public/version.h"
62 #include "tensorflow/core/util/device_name_utils.h"
63
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // IS_MOBILE_PLATFORM.
66 #if !defined(IS_MOBILE_PLATFORM)
67 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
68 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
69 #include "tensorflow/core/distributed_runtime/server_lib.h"
70 #include "tensorflow/core/distributed_runtime/worker_cache.h"
71 #include "tensorflow/core/distributed_runtime/worker_env.h"
72 #endif // !IS_MOBILE_PLATFORM
73
74 namespace tensorflow {
75
namespace eager {
// Forward-declared because of a circular dependency:
//   Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Drop this once TensorHandle no longer depends on Context.
class RemoteMgr;
}  // namespace eager

class EagerOperation;
class TensorHandle;
85
86 class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
87 public:
88 static constexpr uint64 kInvalidContextId = 0;
89
NewContextId()90 static uint64 NewContextId() {
91 uint64 context_id = random::New64();
92 while (context_id == kInvalidContextId) {
93 context_id = random::New64();
94 }
95 return context_id;
96 }
97
98 EagerContext(
99 const SessionOptions& opts,
100 ContextDevicePlacementPolicy default_device_placement_policy, bool async,
101 /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned,
102 /*const*/ Rendezvous* rendezvous,
103 DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
104 CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
105 bool run_eager_op_as_function = false);
106
Release()107 void Release() override { Unref(); }
108
109 AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
110 AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
111 AbstractTensorInterface* CreateInt32Scalar(int32_t value) override;
112 AbstractTensorInterface* CreateFloatScalar(float value) override;
113 AbstractTensorInterface* CreateDoubleScalar(double value) override;
114 AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
115 AbstractTensorInterface* CreateStringScalar(
116 tensorflow::tstring value) override;
117 AbstractTensorInterface* CreateComplex128Scalar(
118 tensorflow::complex128 value) override;
119 AbstractTensorInterface* CreateBoolScalar(bool value) override;
120
121 AbstractTensorInterface* CreateTensor(
122 DataType dtype, absl::Span<const int64> dim_sizes) override;
123 AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
124 int num_dims, void* data, size_t len,
125 MemoryReleaser memory_releaser,
126 void* memory_releaser_arg) override;
127
128 ImmediateExecutionTensorHandle* CreateLocalHandle(
129 AbstractTensorInterface* t) override;
130 // Create an abstract tensor handle from tensorflow::Tensor.
131 ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
132 tensorflow::Tensor& t, const char* d_name) override;
133 ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
134 ImmediateExecutionTensorHandle* handle, const char* device_name,
135 Status* status) override;
136 ImmediateExecutionOperation* CreateOperation() override;
137
138 // This is a virtual helper function to convert TFRT TensorHandle to
139 // tensorflow::TensorHandle. In current runtime EagerContext, just forward
140 // the input since the input tensor handle is already a
141 // tensorflow::TensorHandle.
142 ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
143 ImmediateExecutionTensorHandle* handle) override;
144
145 Status RegisterFunction(AbstractFunction* f) override;
146
147 bool UsesTFRT() override;
148
149 bool RunEagerOpAsFunction() const;
150
151 void ListDevices(std::vector<DeviceAttributes>* devices) override;
152
153 Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override;
154
155 // Returns the function library runtime for the given device.
func_lib(const Device * d)156 FunctionLibraryRuntime* func_lib(const Device* d) const {
157 return pflr_->GetFLR(d->name());
158 }
159
pflr()160 ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
161
runner()162 std::function<void(std::function<void()>)>* runner() { return &runner_; }
163
164 // Specify a executor for this thread.
165 void SetExecutorForThread(EagerExecutor* executor) override;
166
prioritized_device_type_list()167 const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
168 const {
169 mutex_lock l(device_type_list_mu_);
170 return prioritized_device_type_list_;
171 }
172
173 // Clear pending nodes in thread executors and kernel caches.
174 void ClearCachesAndThreadExecutors() override;
175 // Clear pending nodes in default executor and kernel caches.
176 void ClearCachesAndDefaultExecutor();
177
178 // Sets the device placement policy for the current thread.
179 void SetThreadLocalDevicePlacementPolicy(
180 ContextDevicePlacementPolicy policy) override;
181
182 // Returns the device placement policy for the current thread.
183 ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
184
185 // Select an appropriate device for an operation.
186 //
187 // Given the preferred device for the operation, and the node_def, finds the
188 // best suitable device for the operation in this context.
189 //
190 // The preferred device is specified as a `ParsedName` containing the elements
191 // (details) that the resulting device should match. If there are no such
192 // devices, and the context currently allows soft device placement, a suitable
193 // device not matching `preferred` will be chosen.
194 //
195 // The chosen device is stored in the `device` argument. The argument is not
196 // modified unless this method returns `Status::OK()`.
197 Status SelectDevice(DeviceNameUtils::ParsedName preferred,
198 const NodeDef& ndef, Device** out) const;
199
200 bool FindFunctionByName(const string& name) const;
201
202 Status FindFunctionOpData(const string& name,
203 const tensorflow::OpRegistrationData** op_data);
204
205 const FunctionDef* FindFunctionDef(const string& name) const override;
206
HostCPU()207 Device* HostCPU() const { return host_cpu_device_; }
CanonicalDevice(Device * d)208 Device* CanonicalDevice(Device* d) const {
209 return HostCPU() == d ? nullptr : d;
210 }
HostCPUParsedName()211 const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
212 return HostCPU()->parsed_name();
213 }
214
HostCPUName()215 const string& HostCPUName() const override { return HostCPU()->name(); }
216
GetGraphCollector()217 GraphCollector* GetGraphCollector() { return &graph_collector_; }
218
219 EagerExecutor& Executor() override;
220
221 // Add the given `fdef` to the local FunctionLibraryDefinition. And add an
222 // entry to the KernelAndDevice cache for it if it's not exist.
223 Status AddFunctionDef(const FunctionDef& fdef) override;
224
225 Status AddFunctionDefWithStackTraces(
226 const FunctionDef& fdef, const StackTracesMap& stack_traces) override;
227
228 // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
229 // it to the local FunctionLibraryDefinition as well, but no need to add it
230 // to the KernelAndDevice cache since they won't be executed as
231 // KernelAndDevices.
232 Status AddFunctionDef(const FunctionDef& fdef,
233 const FunctionDefLibrary& library,
234 bool add_to_local_only = false,
235 const StackTracesMap& stack_traces = {});
236
237 const FunctionDef* GetFunctionDef(const string& function_name);
238
239 std::vector<string> ListFunctionNames() override;
240
241 Status RemoveFunction(const string& func) override;
242
243 // Wait for pending nodes to be finished in local executors (including context
244 // default executor and thread executors) and executors on remote workers.
245 // Return combined status of remote executors. If there are multiple errors,
246 // the Status code will be the same as the first remote executor that has
247 // errors, and the error message will be combined from all executors.
248 Status SyncExecutors();
249
AsyncWait()250 Status AsyncWait() override { return SyncExecutors(); }
251
252 core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
253
254 void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
255
LogDevicePlacement()256 bool LogDevicePlacement() const { return log_device_placement_; }
SetLogDevicePlacement(bool enable)257 void SetLogDevicePlacement(bool enable) override {
258 log_device_placement_ = enable;
259 }
260
261 // When tensor transfer across functions/eager executions using send/recv ops
262 // are required, `reuse_rendezvous_for_functions_` can be set to true so that
263 // function executions and eager executions use the same rendezvous instance,
264 // instead of creating new instance per function calls.
SetReuseRendezvousForFunctions(bool reuse_rendezvous_for_functions)265 void SetReuseRendezvousForFunctions(
266 bool reuse_rendezvous_for_functions) override {
267 reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
268 }
GetReuseRendezvousForFunctions()269 bool GetReuseRendezvousForFunctions() const {
270 return reuse_rendezvous_for_functions_;
271 }
272
AllowSoftPlacement()273 bool AllowSoftPlacement() const { return allow_soft_placement_; }
SetAllowSoftPlacement(bool enable)274 void SetAllowSoftPlacement(bool enable) override {
275 allow_soft_placement_ = enable;
276 }
LogMemory()277 bool LogMemory() const { return log_memory_; }
278
GetRendezvous()279 Rendezvous* GetRendezvous() const { return rendezvous_; }
280
ResetGlobalRendezvousForFunction()281 void ResetGlobalRendezvousForFunction() override {
282 mutex_lock l(global_rendezvous_mu_);
283 global_rendezvous_for_functions_ =
284 core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
285 }
286
287 // Returns a function which maps from step_id to rendezvous. This closure
288 // respects the value of `SetReuseRendezvousForFunctions` at the time the
289 // closure was created, which allows the setting to be toggled around async op
290 // launches.
291 //
292 // The caller of the returned function owns a reference to the resulting
293 // Rendezvous.
RendezvousCreator()294 std::function<Rendezvous*(int64_t)> RendezvousCreator() {
295 if (reuse_rendezvous_for_functions_) {
296 return [this](int64_t step_id) {
297 mutex_lock l(global_rendezvous_mu_);
298 global_rendezvous_for_functions_->Ref();
299 return global_rendezvous_for_functions_.get();
300 };
301 } else {
302 return [this](int64_t step_id) { return CreateRendezvous(step_id); };
303 }
304 }
305
collective_executor_mgr()306 CollectiveExecutorMgrInterface* collective_executor_mgr() {
307 return collective_executor_mgr_.Get();
308 }
GetCollectiveExecutorHandle()309 std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
310 return std::unique_ptr<CollectiveExecutor::Handle>(
311 new CollectiveExecutor::Handle(
312 collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
313 }
314
local_device_mgr()315 tensorflow::DeviceMgr* local_device_mgr() const {
316 return local_device_manager_.Get();
317 }
remote_device_mgr()318 const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
319 return remote_device_manager_.Get();
320 }
321
GetOwnedRemoteDeviceMgr()322 tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
323 return remote_device_manager_.GetOwned();
324 }
325
ListLocalTfDevices()326 std::vector<Device*> ListLocalTfDevices() override {
327 return local_device_mgr()->ListDevices();
328 }
329
330 // TODO(apassos) clean up RunMetadata storage.
MetadataMu()331 mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
332 bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
333 void SetShouldStoreGraphs(bool value) override;
RunMetadataProto()334 RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
335 return run_metadata_.get();
336 }
337 std::unique_ptr<RunMetadata> ExportRunMetadata() override
338 TF_LOCKS_EXCLUDED(metadata_mu_);
339
340 void StartStep() override;
341 void EndStep() override;
342 ScopedStepContainer* StepContainer();
343
FuncLibDef()344 FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
345
346 #if !defined(IS_MOBILE_PLATFORM)
347 // Assign the EagerClient pointer to `client` based on the given device / task
348 // name, and increment the refcount of the client. The reference ownership is
349 // transferred to the caller, and the unref should automatically happen when
350 // destructing the RefCountPtr object at the caller's side.
351 // `client` must not be initialized or holding a reference of another object
352 // before calling this method.
353 Status GetClient(Device* device,
354 core::RefCountPtr<eager::EagerClient>* client);
355 Status GetClient(const DeviceNameUtils::ParsedName& device_name,
356 core::RefCountPtr<eager::EagerClient>* client);
357 Status GetClient(const string& remote_task,
358 core::RefCountPtr<eager::EagerClient>* client);
359
360 uint64 GetContextId() const;
361 uint64 GetContextViewId() const;
362 void IncrementContextViewId();
363
364 Status EnableCollectiveOps(const ServerDef& server_def) override;
365
366 // TODO(nareshmodi): Encapsulate remote state into a separate
367 // class/struct.
368 //
369 // Enables the eager context to communicate with remote devices. When
370 // initializing with this method, this context will be the primary context,
371 // which will kill all its remote contexts in shutdown.
372 //
373 // - server: A ServerInterface that exports the tensorflow.WorkerService.
374 // Note that this class expects the server to already have been started.
375 // - remote_eager_workers: A cache from which we can get "EagerClient"s to
376 // communicate with remote eager services.
377 // - remote_device_mgr: A DeviceMgr* which contains all remote devices
378 // (should contain no local devices).
379 // - remote_contexts: A vector containing task names.
380 // TODO(b/184375824): clean up parameter order for better readability.
381 Status InitializeRemoteMaster(
382 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
383 std::shared_ptr<WorkerSession> worker_session,
384 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
385 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
386 const std::vector<string>& remote_contexts, uint64 context_id,
387 /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr,
388 int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
389 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
390 remote_mgr);
391
392 // Update an existing master context with a new set of remote workers (i.e., a
393 // new "view" of cluster membership. Similar to InitializeRemoteMaster but
394 // this will keep the current context_id and increment a context_view_id, will
395 // keep the current resource manager so that resources from the previous view
396 // can still be accessed, and will automatically register existing functions
397 // if there are newly added hosts.
398 Status UpdateRemoteMaster(
399 uint64 context_id,
400 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
401 const std::vector<string>& add_remote_contexts,
402 const std::vector<string>& remove_remote_contexts);
403
404 // Similar with InitializeRemoteMaster but this context will not kill remote
405 // contexts in shutdown.
406 Status InitializeRemoteWorker(
407 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
408 DynamicDeviceMgr* remote_device_mgr,
409 const std::vector<string>& remote_contexts, uint64 context_id,
410 uint64 context_view_id,
411 std::function<Rendezvous*(const int64_t)> rendezvous_creator,
412 DistributedFunctionLibraryRuntime* cluster_flr,
413 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
414 remote_mgr,
415 std::function<void()> resource_deallocator);
416
417 // Similar with InitializeRemoteWorker but will reuse existing context and
418 // increment context_view_id.
419 Status UpdateRemoteWorker(
420 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
421 const std::vector<string>& remote_contexts, uint64 context_id);
422
423 Status StoreCollectiveOpsServer(
424 std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
425 CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
426
427 // For the specified remote worker, preprocess and set its device filters.
428 Status SetRemoteDeviceFilters(const string& remote_worker,
429 const std::vector<string>& device_filters);
430
431 // For the specified remote worker, apply the stored device filters to the
432 // list of device attributes following these rules:
433 // (1) if the remote worker does not have device filters, all devices are
434 // visible to the worker;
435 // (2) if the device is on the remote worker, then it is visible;
436 // (3) if the device matches at least one device filter, then it is visible.
437 // The result is saved as a boolean vector of the same length (i.e.,
438 // filtered_device_mask) indicating whether each of the devices is visible to
439 // the remote worker.
440 void FilterDevicesForRemoteWorkers(
441 const string& remote_worker,
442 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
443 std::vector<bool>* filtered_device_mask);
444
445 // TODO(fishx): Remove the custom deleter once we remove forward declaration.
446 const std::unique_ptr<eager::RemoteMgr,
447 std::function<void(eager::RemoteMgr*)>>&
RemoteMgr()448 RemoteMgr() {
449 return remote_mgr_;
450 }
451
452 // If true, then tensors should be shipped across processes via the
453 // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
454 // used instead (which in-turn use WorkerService.RecvTensor RPCs).
UseSendTensorRPC()455 bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
456
GetServer()457 tensorflow::ServerInterface* GetServer() { return server_.get(); }
458
459 // For LLVM style RTTI.
classof(const AbstractContext * ptr)460 static bool classof(const AbstractContext* ptr) {
461 return ptr->getKind() == kEager;
462 }
463
464 // Function to support distributed C API.
SetDistributedManager(std::unique_ptr<ImmediateExecutionDistributedManager> distributed)465 void SetDistributedManager(
466 std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
467 override {
468 distributed_manager_ = std::move(distributed);
469 }
GetDistributedManager()470 ImmediateExecutionDistributedManager* GetDistributedManager() override {
471 return distributed_manager_.get();
472 }
473
474 // May only be used during multi-client setup so that a RemoteRendezvous
475 // can be initialized instead of defaulting to the IntraProcessRendezvous.
476 void SetWorkerEnv(WorkerEnv* worker_env,
477 std::shared_ptr<WorkerSession> worker_session);
478 #endif // IS_MOBILE_PLATFORM
479
480 // Closes remote eager contexts, waits for all RPCs to finish, and
481 // destroys the EagerClientCache. No RPCs can be made through this context
482 // after this method has been called.
483 // This method exists to aid a clean shutdown. It causes all RPCs to finish
484 // and remote TensorHandles to release their references to this context.
485 // To avoid deadlocks, this method must not be called on the thread
486 // processing RPCs because it makes RPCs and waits for their completion.
487 //
488 // On mobile, it just cleans the caches.
489 void WaitForAndCloseRemoteContexts();
490
PinSmallOpsToCPU()491 bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }
492
TFEnv()493 tensorflow::Env* TFEnv() const { return env_; }
494
495 Status FindDeviceFromName(const char* device_name, Device** device) const;
496
497 Status FindCompositeDeviceFromName(StringPiece device_name,
498 CompositeDevice** device) const;
499
500 Status RegisterCustomDevice(const string& name,
501 std::unique_ptr<CustomDevice> device) override;
502
GetCustomDeviceOpHandler()503 CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
504 return custom_device_op_handler_;
505 };
506
507 // Find or create a composite device with the given `underlying_devices` and
508 // `device_name` (if not empty).
509 Status FindOrCreateCompositeDevice(
510 const std::vector<string>& underlying_devices, const string& device_name,
511 CompositeDevice** composite_device);
512
513 bool OnSameTask(const Device* first, const Device* second) const;
514 // Gets the CPU device on the task of device.
515 Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
516
session_options()517 const SessionOptions& session_options() const { return opts_; }
518 void InitPrioritizedDeviceTypeList();
519
520 private:
CreateRendezvous(int64_t step_id)521 Rendezvous* CreateRendezvous(int64_t step_id) const {
522 if (rendezvous_creator_ != nullptr) {
523 return rendezvous_creator_(step_id);
524 }
525
526 #if !defined(IS_MOBILE_PLATFORM)
527 if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
528 auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
529 remote_r->Initialize(worker_session_.get()).IgnoreError();
530 return remote_r;
531 }
532 #endif
533
534 if (remote_device_mgr() == nullptr) {
535 return new IntraProcessRendezvous(local_device_mgr());
536 }
537
538 return nullptr;
539 }
540
541 ~EagerContext() override;
542
543 Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
544 Status RegisterExistingFunctionsOnRemoteWorkers(
545 const std::vector<string>& remote_workers);
546
547 void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
548 const ConfigProto* config, int graph_def_version,
549 const FunctionLibraryDefinition* lib_def,
550 const OptimizerOptions& optimizer_options,
551 thread::ThreadPool* thread_pool = nullptr,
552 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
553
554 void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
555
556 void ClearResourceContainer(const string& name);
557
558 template <typename T>
559 struct OwnedOrUnownedHelper {
560 public:
OwnedOrUnownedHelperOwnedOrUnownedHelper561 OwnedOrUnownedHelper() {}
562 explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
563 Reset(object, owned);
564 }
565
ResetOwnedOrUnownedHelper566 void Reset(std::unique_ptr<T> object) {
567 owned_object = std::move(object);
568 unowned_object_ptr = nullptr;
569 }
570
571 void Reset(T* object, const bool owned = false) {
572 if (owned) {
573 owned_object.reset(object);
574 unowned_object_ptr = nullptr;
575 } else {
576 owned_object.reset(nullptr);
577 unowned_object_ptr = object;
578 }
579 }
580
OwnedOwnedOrUnownedHelper581 bool Owned() const { return owned_object != nullptr; }
582
GetOwnedOwnedOrUnownedHelper583 T* GetOwned() const { return owned_object.get(); }
GetOwnedOrUnownedHelper584 T* Get() const {
585 return owned_object ? owned_object.get() : unowned_object_ptr;
586 }
587
588 std::unique_ptr<T> owned_object = nullptr;
589 T* unowned_object_ptr = nullptr;
590 };
591
592 SessionOptions opts_;
593 const ContextDevicePlacementPolicy default_device_placement_policy_;
594
595 // Note: we cannot use C++11 thread_local here as there is no concept of a
596 // thread-local-object-local variable in C++11.
597 mutable mutex policy_map_mu_;
598 std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
599 device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);
600
601 OwnedOrUnownedHelper<DeviceMgr> local_device_manager_;
602 // Maintain copy of all previously created local device managers.
603 std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_;
604
605 // Unowned DynamicDeviceMgr is set on remote worker to allow running
606 // multi-device function on remote worker.
607 OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
608
609 Device* host_cpu_device_; // Owned by device_manager
610 mutable mutex device_type_list_mu_;
611 std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
612 TF_GUARDED_BY(device_type_list_mu_);
613 Rendezvous* rendezvous_;
614 std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
615 CustomDeviceOpHandler custom_device_op_handler_;
616
617 mutable mutex composite_devices_mu_;
618 // Maps from the fingerprint of a set of device names to a virtual
619 // CompositeDevice.
620 // TODO(b/145922293): Consider taking device names as keys.
621 absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
622 composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);
623
Global()624 FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
625
626 std::unique_ptr<thread::ThreadPool> thread_pool_;
627
628 // EagerContext owns the DistributedFunctionLibraryRuntime(
629 // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
630 // function execution (lazy_copy_function_remote_inputs_=true).
631 OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
632 // One FunctionLibraryRuntime per device.
633 // func_libs[i] is the FunctionLibraryRuntime corresponding to
634 // session->devices[i].
635 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
636
637 std::function<void(std::function<void()>)> runner_;
638
639 mutex cache_mu_;
640 struct RegisteredFunction : public core::RefCounted {
~RegisteredFunctionRegisteredFunction641 ~RegisteredFunction() override {}
642
643 std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
644 };
645 std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
646 Fprint128Hasher>
647 kernel_cache_ TF_GUARDED_BY(cache_mu_);
648 std::unordered_map<string, RegisteredFunction*> registered_functions_
649 TF_GUARDED_BY(cache_mu_);
650
651 // Whether we should compute RunMetadata.
652 std::atomic<bool> should_store_graphs_{false};
653 mutex metadata_mu_;
654 std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
655 GraphCollector graph_collector_;
656 std::atomic<bool> log_device_placement_;
657 std::atomic<bool> allow_soft_placement_;
658
659 // Information related to step containers.
660 std::atomic<int> num_active_steps_;
661 std::unique_ptr<ScopedStepContainer> step_container_
662 TF_GUARDED_BY(metadata_mu_);
663
664 EagerExecutor default_executor_;
665 mutable mutex executor_map_mu_;
666 // Not owned.
667 std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
668 TF_GUARDED_BY(executor_map_mu_);
669 std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
670 has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
671
672 const bool log_memory_;
673
674 // Whether to use same rendezvous instance across function/eager executions.
675 bool reuse_rendezvous_for_functions_ = false;
676 mutable mutex global_rendezvous_mu_;
677 core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_
678 TF_GUARDED_BY(global_rendezvous_mu_);
679
680 Env* const env_;
681
682 OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;
683
684 #if !defined(IS_MOBILE_PLATFORM)
685 std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
686 bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
687 void CloseAndClearAllRemoteContexts();
688 void CloseRemoteContexts(const std::vector<string>& remote_contexts,
689 uint64 context_id, uint64 context_view_id);
690
691 // TODO(b/184375824): clean up parameter order for better readability.
692 Status SetMasterContextState(
693 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
694 std::shared_ptr<WorkerSession> worker_session,
695 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
696 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
697 uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r,
698 /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs,
699 DistributedFunctionLibraryRuntime* cluster_flr,
700 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
701 remote_mgr);
702
703 // The server_ is not const since we release it when the context is destroyed.
704 // Therefore the server_ object is not marked as const (even though it should
705 // be).
706 std::unique_ptr<ServerInterface> server_;
707 WorkerEnv* worker_env_ = nullptr;
708 std::shared_ptr<WorkerSession> worker_session_;
709
710 mutable mutex remote_state_mu_;
711
712 uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
713 // The view id of an eager context should be set to 0 when context is created,
714 // and continuously incremented when context with the same context_id gets
715 // updated. The view id should be consistent between master and workers.
716 uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
717 std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
718 std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
719 TF_GUARDED_BY(remote_state_mu_);
720
721 int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
722 std::atomic<int> sleep_for_secs_;
723
724 std::unique_ptr<Thread> keep_alive_thread_;
725 mutex keep_alive_thread_shutdown_mu_;
726 condition_variable keep_alive_thread_cv_;
727 bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
728
729 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
730 remote_mgr_;
731 bool is_master_ TF_GUARDED_BY(remote_state_mu_);
732
733 // Maps from a remote worker to a list of parsed device filters.
734 std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
735 cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);
736
737 // A distributed manager that helps setup, update, and check liveness of
738 // member tasks in the cluster.
739 std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;
740
741 #endif // IS_MOBILE_PLATFORM
742
743 // For a multi device function, the target device of each input is unknown
744 // until the function is instantiated on the default function device.
745 // If false, eagerly copy all remote inputs to the default function device;
746 // if true, lazily copy remote inputs to their target devices to avoid
747 // redundant copies.
748 bool lazy_copy_function_remote_inputs_ = false;
749 bool use_send_tensor_rpc_;
750 const bool pin_small_ops_to_cpu_;
751
752 // Function that will be invoked in destructor to deallocate resources related
753 // to this context.
754 std::function<void()> resource_deallocator_ = nullptr;
755 bool run_eager_op_as_function_;
756 };
757
ContextFromInterface(ImmediateExecutionContext * context)758 inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
759 return down_cast<EagerContext*>(context);
760 }
761
762 namespace internal {
763 struct EagerContextDeleter {
operatorEagerContextDeleter764 void operator()(EagerContext* p) const {
765 if (p != nullptr) {
766 p->Release();
767 }
768 }
769 };
770 } // namespace internal
771
772 using EagerContextPtr =
773 std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
774
775 } // namespace tensorflow
776
777 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
778