/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"

// "tensorflow/core/platform/platform.h" must be included first before using
// IS_MOBILE_PLATFORM.
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif  // !IS_MOBILE_PLATFORM
namespace tensorflow {

namespace eager {
// We need this forward declaration because we have a circular dependency:
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove the Context dependency in
// TensorHandle.
class RemoteMgr;
}  // namespace eager

class TensorHandle;
class EagerOperation;

class EagerContext : public ImmediateExecutionContext,
                     public core::RefCounted {
 public:
  static constexpr uint64 kInvalidContextId = 0;

  static uint64 NewContextId() {
    uint64 context_id = random::New64();
    while (context_id == kInvalidContextId) {
      context_id = random::New64();
    }
    return context_id;
  }

  EagerContext(const SessionOptions& opts,
               ContextDevicePlacementPolicy default_device_placement_policy,
               bool async, const DeviceMgr* device_mgr, bool device_mgr_owned,
               Rendezvous* rendezvous,
               DistributedFunctionLibraryRuntime* cluster_flr = nullptr);

  void Release() override { Unref(); }

  AbstractTensorInterface* CreateInt64Scalar(int64 value) override;
  AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
  AbstractTensorInterface* CreateInt32Scalar(int32 value) override;
  AbstractTensorInterface* CreateFloatScalar(float value) override;
  AbstractTensorInterface* CreateDoubleScalar(double value) override;
  AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
  AbstractTensorInterface* CreateStringScalar(
      tensorflow::tstring value) override;
  AbstractTensorInterface* CreateComplex128Scalar(
      tensorflow::complex128 value) override;
  AbstractTensorInterface* CreateBoolScalar(bool value) override;

  AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64> dim_sizes) override;
  AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
                                        int num_dims, void* data, size_t len,
                                        MemoryReleaser memory_releaser,
                                        void* memory_releaser_arg) override;

  ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) override;
  // Creates an abstract tensor handle from a tensorflow::Tensor.
  ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) override;
  ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) override;
  ImmediateExecutionOperation* CreateOperation() override;

  // This is a virtual helper function to convert a TFRT TensorHandle to a
  // tensorflow::TensorHandle. In the current runtime EagerContext it simply
  // forwards the input, since the input tensor handle is already a
  // tensorflow::TensorHandle.
  ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) override;

  Status RegisterFunction(AbstractFunction* f) override;

  bool UsesTFRT() override;

  void ListDevices(std::vector<DeviceAttributes>* devices) override;

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(const Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Specifies an executor for the current thread.
  void SetExecutorForThread(EagerExecutor* executor) override;

  const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
      const {
    mutex_lock l(device_type_list_mu_);
    return prioritized_device_type_list_;
  }

  // Clears pending nodes in thread executors and kernel caches.
  void ClearCachesAndThreadExecutors() override;
  // Clears pending nodes in the default executor and kernel caches.
  void ClearCachesAndDefaultExecutor();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) override;

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;

  // Selects an appropriate device for an operation.
  //
  // Given the preferred device for the operation and the node_def, finds the
  // most suitable device for the operation in this context.
  //
  // The preferred device is specified as a `ParsedName` containing the
  // elements (details) that the resulting device should match. If there are
  // no such devices, and the context currently allows soft device placement,
  // a suitable device not matching `preferred` will be chosen.
  //
  // The chosen device is stored in the `out` argument. The argument is not
  // modified unless this method returns `Status::OK()`.
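  //
  // Example (a sketch; assumes an initialized EagerContext* ctx and a NodeDef
  // `ndef`; the device name is illustrative):
  //
  //   DeviceNameUtils::ParsedName preferred;
  //   DeviceNameUtils::ParseFullName(
  //       "/job:localhost/replica:0/task:0/device:GPU:0", &preferred);
  //   Device* device = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->SelectDevice(preferred, ndef, &device));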
  Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                      const NodeDef& ndef, Device** out) const;

  bool FindFunctionByName(const string& name) const;

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name) const override;

  Device* HostCPU() const { return host_cpu_device_; }
  // Returns nullptr for the host CPU device (its canonical representation);
  // otherwise returns `d` itself.
  Device* CanonicalDevice(Device* d) const {
    return HostCPU() == d ? nullptr : d;
  }
  const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
    return HostCPU()->parsed_name();
  }

  const string& HostCPUName() const override { return HostCPU()->name(); }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  EagerExecutor& Executor() override;

  // Adds the given `fdef` to the local FunctionLibraryDefinition, and adds an
  // entry to the KernelAndDevice cache for it if one does not already exist.
  Status AddFunctionDef(const FunctionDef& fdef) override;

  Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) override;

  // `library` contains all FunctionDefs and GradientDefs to expand `fdef`.
  // Add it to the local FunctionLibraryDefinition as well, but there is no
  // need to add it to the KernelAndDevice cache since these functions won't
  // be executed as KernelAndDevices.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const FunctionDefLibrary& library,
                        bool add_to_local_only = false,
                        const StackTracesMap& stack_traces = {});

  const FunctionDef* GetFunctionDef(const string& function_name);

  std::vector<string> ListFunctionNames() override;

  Status RemoveFunction(const string& func) override;

  // Waits for pending nodes to finish in local executors (including the
  // context default executor and thread executors) as well as executors on
  // remote workers. Returns the combined status of the remote executors. If
  // there are multiple errors, the Status code will be the same as that of
  // the first remote executor with an error, and the error message will be
  // combined from all executors.
  Status SyncExecutors();

  Status AsyncWait() override { return SyncExecutors(); }

  core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);

  bool LogDevicePlacement() const { return log_device_placement_; }
  void SetLogDevicePlacement(bool enable) override {
    log_device_placement_ = enable;
  }

  // When tensors need to be transferred across functions/eager executions
  // via send/recv ops, `reuse_rendezvous_for_functions_` can be set to true
  // so that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
  void SetReuseRendezvousForFunctions(bool reuse_rendezvous_for_functions) {
    reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
  }
  bool GetReuseRendezvousForFunctions() const {
    return reuse_rendezvous_for_functions_;
  }
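  //
  // Example (a sketch; assumes an initialized EagerContext* ctx): reuse one
  // rendezvous instance across a group of function executions, then restore
  // per-call rendezvous creation:
  //
  //   ctx->SetReuseRendezvousForFunctions(true);
  //   // ... launch functions that exchange tensors via send/recv ops ...
  //   ctx->SetReuseRendezvousForFunctions(false);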

  bool AllowSoftPlacement() const { return allow_soft_placement_; }
  void SetAllowSoftPlacement(bool enable) override {
    allow_soft_placement_ = enable;
  }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }

  // Returns a function which maps from step_id to rendezvous. This closure
  // respects the value of `SetReuseRendezvousForFunctions` at the time the
  // closure was created, which allows the setting to be toggled around async
  // op launches.
  //
  // The caller of the returned function owns a reference to the resulting
  // Rendezvous.
  std::function<Rendezvous*(int64)> RendezvousCreator() {
    if (reuse_rendezvous_for_functions_) {
      return [this](int64 step_id) {
        // Increment reference count as `rendezvous_` will be unref'ed after
        // function execution.
        rendezvous_->Ref();
        return rendezvous_;
      };
    } else {
      return [this](int64 step_id) { return CreateRendezvous(step_id); };
    }
  }
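  //
  // Example (a sketch; assumes an initialized EagerContext* ctx):
  //
  //   std::function<Rendezvous*(int64)> create = ctx->RendezvousCreator();
  //   Rendezvous* rendezvous = create(/*step_id=*/42);
  //   // ... use the rendezvous for send/recv ...
  //   rendezvous->Unref();  // The caller owns one reference.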

  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return collective_executor_mgr_.Get();
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }

  const tensorflow::DeviceMgr* local_device_mgr() const {
    return local_device_manager_.Get();
  }
  const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.Get();
  }

  tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
    return remote_device_manager_.GetOwned();
  }

  std::vector<Device*> ListLocalTfDevices() override {
    return local_device_mgr()->ListDevices();
  }

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value) override;
  RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
    return run_metadata_.get();
  }
  std::unique_ptr<RunMetadata> ExportRunMetadata() override
      TF_LOCKS_EXCLUDED(metadata_mu_);

  void StartStep() override;
  void EndStep() override;
  ScopedStepContainer* StepContainer();

  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#if !defined(IS_MOBILE_PLATFORM)
  // Assigns the EagerClient pointer to `client` based on the given device /
  // task name, and increments the refcount of the client. Ownership of the
  // reference is transferred to the caller, and the unref happens
  // automatically when the caller's RefCountPtr is destroyed.
  // `client` must not already be initialized or hold a reference to another
  // object before calling this method.
  Status GetClient(Device* device,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const DeviceNameUtils::ParsedName& device_name,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const string& remote_task,
                   core::RefCountPtr<eager::EagerClient>* client);
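  //
  // Example (a sketch; the remote task name is illustrative):
  //
  //   core::RefCountPtr<eager::EagerClient> client;
  //   TF_RETURN_IF_ERROR(
  //       ctx->GetClient("/job:worker/replica:0/task:1", &client));
  //   // The reference is released when `client` goes out of scope.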

  uint64 GetContextId() const;
  uint64 GetContextViewId() const;
  void IncrementContextViewId();

  // TODO(nareshmodi): Encapsulate remote state into a separate
  // class/struct.
  //
  // Enables the eager context to communicate with remote devices. When
  // initializing with this method, this context will be the primary context,
  // which will kill all its remote contexts in shutdown.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  //   Note that this class expects the server to already have been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  //   communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  //   (should contain no local devices).
  // - remote_contexts: A vector containing task names.
  Status InitializeRemoteMaster(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      const std::vector<string>& remote_contexts, uint64 context_id,
      Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // Updates an existing master context with a new set of remote workers
  // (i.e., a new "view" of cluster membership). Similar to
  // InitializeRemoteMaster, but this will keep the current context_id and
  // increment a context_view_id, will keep the current resource manager so
  // that resources from the previous view can still be accessed, and will
  // automatically register existing functions if there are newly added hosts.
  Status UpdateRemoteMaster(
      uint64 context_id,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& add_remote_contexts,
      const std::vector<string>& remove_remote_contexts);

  // Similar to InitializeRemoteMaster, but this context will not kill remote
  // contexts in shutdown.
  Status InitializeRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      DynamicDeviceMgr* remote_device_mgr,
      const std::vector<string>& remote_contexts, uint64 context_id,
      uint64 context_view_id,
      std::function<Rendezvous*(const int64)> rendezvous_creator,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr,
      std::function<void()> resource_deallocator);

  // Similar to InitializeRemoteWorker, but reuses the existing context and
  // increments context_view_id.
  Status UpdateRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& remote_contexts, uint64 context_id);

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);

  // For the specified remote worker, preprocess and set its device filters.
  Status SetRemoteDeviceFilters(const string& remote_worker,
                                const std::vector<string>& device_filters);

  // For the specified remote worker, apply the stored device filters to the
  // list of device attributes following these rules:
  // (1) if the remote worker does not have device filters, all devices are
  //     visible to the worker;
  // (2) if the device is on the remote worker, then it is visible;
  // (3) if the device matches at least one device filter, then it is visible.
  // The result is saved as a boolean vector of the same length (i.e.,
  // filtered_device_mask) indicating whether each of the devices is visible
  // to the remote worker. See the example after this declaration.
  void FilterDevicesForRemoteWorkers(
      const string& remote_worker,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
      std::vector<bool>* filtered_device_mask);
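  //
  // Example (illustrative names): with device filter "/job:ps" stored for
  // remote worker "/job:worker/replica:0/task:1", the mask would be:
  //
  //   "/job:ps/replica:0/task:0/device:CPU:0"     -> true  (matches filter)
  //   "/job:worker/replica:0/task:1/device:GPU:0" -> true  (on the worker)
  //   "/job:worker/replica:0/task:2/device:CPU:0" -> false (filtered out)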

  // TODO(fishx): Remove the custom deleter once we remove the forward
  // declaration.
  const std::unique_ptr<eager::RemoteMgr,
                        std::function<void(eager::RemoteMgr*)>>&
  RemoteMgr() {
    return remote_mgr_;
  }

  // If true, then tensors should be shipped across processes via the
  // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
  // used instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }

  tensorflow::ServerInterface* GetServer() { return server_.get(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager;
  }

  // Functions to support the distributed C API.
  void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
      override {
    distributed_manager_ = std::move(distributed);
  }
  ImmediateExecutionDistributedManager* GetDistributedManager() override {
    return distributed_manager_.get();
  }
#endif  // !IS_MOBILE_PLATFORM

  // Closes remote eager contexts, waits for all RPCs to finish, and
  // destroys the EagerClientCache. No RPCs can be made through this context
  // after this method has been called.
  // This method exists to aid a clean shutdown. It causes all RPCs to finish
  // and remote TensorHandles to release their references to this context.
  // To avoid deadlocks, this method must not be called on the thread
  // processing RPCs because it makes RPCs and waits for their completion.
  //
  // On mobile, it just cleans the caches.
  void WaitForAndCloseRemoteContexts();

  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  Status FindDeviceFromName(const char* device_name, Device** device) const;

  Status FindCompositeDeviceFromName(StringPiece device_name,
                                     CompositeDevice** device) const;

  Status RegisterCustomDevice(const string& name,
                              std::unique_ptr<CustomDevice> device) override;

  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
    return custom_device_op_handler_;
  }

  // Finds or creates a composite device with the given `underlying_devices`
  // and `device_name` (if not empty).
  Status FindOrCreateCompositeDevice(
      const std::vector<string>& underlying_devices, const string& device_name,
      CompositeDevice** composite_device);

  bool OnSameTask(const Device* first, const Device* second) const;
  // Gets the CPU device on the task of the given device.
  Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;

  const SessionOptions& session_options() const { return opts_; }
  void InitPrioritizedDeviceTypeList();

 private:
  Rendezvous* CreateRendezvous(int64 step_id) const {
    if (rendezvous_creator_ != nullptr) {
      return rendezvous_creator_(step_id);
    }

#if !defined(IS_MOBILE_PLATFORM)
    if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
      auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
      remote_r->Initialize(worker_session_.get()).IgnoreError();
      return remote_r;
    }
#endif

    if (remote_device_mgr() == nullptr) {
      return new IntraProcessRendezvous(local_device_mgr());
    }

    return nullptr;
  }

  ~EagerContext() override;

  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
  Status RegisterExistingFunctionsOnRemoteWorkers(
      const std::vector<string>& remote_workers);

  void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                 const ConfigProto* config, int graph_def_version,
                 const FunctionLibraryDefinition* lib_def,
                 const OptimizerOptions& optimizer_options,
                 thread::ThreadPool* thread_pool = nullptr,
                 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);

  void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);

  void ClearResourceContainer(const string& name);

  template <typename T>
  struct OwnedOrUnownedHelper {
   public:
    OwnedOrUnownedHelper() {}
    explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
      Reset(object, owned);
    }

    void Reset(std::unique_ptr<T> object) {
      owned_object = std::move(object);
      unowned_object_ptr = nullptr;
    }

    void Reset(T* object, const bool owned = false) {
      if (owned) {
        owned_object.reset(object);
        unowned_object_ptr = nullptr;
      } else {
        owned_object.reset(nullptr);
        unowned_object_ptr = object;
      }
    }

    bool Owned() const { return owned_object != nullptr; }

    T* GetOwned() const { return owned_object.get(); }
    T* Get() const {
      return owned_object ? owned_object.get() : unowned_object_ptr;
    }

    std::unique_ptr<T> owned_object = nullptr;
    T* unowned_object_ptr = nullptr;
  };
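
  // Usage (a sketch; `owned_device_mgr` and `external_mgr` are illustrative):
  //
  //   OwnedOrUnownedHelper<const DeviceMgr> mgr;
  //   mgr.Reset(std::move(owned_device_mgr));    // takes ownership
  //   mgr.Reset(external_mgr, /*owned=*/false);  // borrows a raw pointer
  //   const DeviceMgr* m = mgr.Get();            // works in either case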

  SessionOptions opts_;
  const ContextDevicePlacementPolicy default_device_placement_policy_;

  // Note: we cannot use C++11 thread_local here as there is no concept of a
  // thread-local-object-local variable in C++11.
  mutable mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);

  OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
  // Maintains copies of all previously created local device managers.
  std::vector<std::unique_ptr<const DeviceMgr>> old_local_device_managers_;

  // An unowned DynamicDeviceMgr is set on remote workers to allow running
  // multi-device functions on them.
  OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;

  Device* host_cpu_device_;  // Owned by device_manager
  mutable mutex device_type_list_mu_;
  std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
      TF_GUARDED_BY(device_type_list_mu_);
  Rendezvous* rendezvous_;
  std::function<Rendezvous*(const int64)> rendezvous_creator_;
  CustomDeviceOpHandler custom_device_op_handler_;

  mutable mutex composite_devices_mu_;
  // Maps from the fingerprint of a set of device names to a virtual
  // CompositeDevice.
  // TODO(b/145922293): Consider taking device names as keys.
  absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
      composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);

  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // EagerContext owns the DistributedFunctionLibraryRuntime
  // (EagerClusterFunctionLibraryRuntime) if using EagerService for remote
  // function execution (lazy_copy_function_remote_inputs_=true).
  OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  struct RegisteredFunction : public core::RefCounted {
    ~RegisteredFunction() override {}

    std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
  };
  std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
                     Fprint128Hasher>
      kernel_cache_ TF_GUARDED_BY(cache_mu_);
  std::unordered_map<string, RegisteredFunction*> registered_functions_
      TF_GUARDED_BY(cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
  GraphCollector graph_collector_;
  std::atomic<bool> log_device_placement_;
  std::atomic<bool> allow_soft_placement_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_
      TF_GUARDED_BY(metadata_mu_);

  EagerExecutor default_executor_;
  mutable mutex executor_map_mu_;
  // Not owned.
  std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
      TF_GUARDED_BY(executor_map_mu_);
  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);

  const bool log_memory_;

  // Whether to use the same rendezvous instance across function/eager
  // executions.
  bool reuse_rendezvous_for_functions_ = false;

  Env* const env_;

  OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;

#if !defined(IS_MOBILE_PLATFORM)
  std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
  bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
  void CloseAndClearAllRemoteContexts();
  void CloseRemoteContexts(const std::vector<string>& remote_contexts,
                           uint64 context_id, uint64 context_view_id);

  Status SetMasterContextState(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      uint64 context_id, uint64 context_view_id, Rendezvous* r,
      const DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // server_ is not marked const because we release it when the context is
  // destroyed, even though it is conceptually const.
  std::unique_ptr<ServerInterface> server_;
  WorkerEnv* worker_env_ = nullptr;
  std::shared_ptr<WorkerSession> worker_session_;

  mutable mutex remote_state_mu_;

  uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
  // The view id of an eager context should be set to 0 when the context is
  // created, and continuously incremented when a context with the same
  // context_id gets updated. The view id should be consistent between the
  // master and workers.
  uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
  std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
      TF_GUARDED_BY(remote_state_mu_);

  int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;

  std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
      remote_mgr_;
  bool is_master_ TF_GUARDED_BY(remote_state_mu_);

  // Maps from a remote worker to a list of parsed device filters.
  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
      cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);

  // A distributed manager that helps set up, update, and check liveness of
  // member tasks in the cluster.
  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;

#endif  // !IS_MOBILE_PLATFORM

  // For a multi-device function, the target device of each input is unknown
  // until the function is instantiated on the default function device.
  // If false, eagerly copy all remote inputs to the default function device;
  // if true, lazily copy remote inputs to their target devices to avoid
  // redundant copies.
  bool lazy_copy_function_remote_inputs_ = false;
  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;

  // Function that will be invoked in the destructor to deallocate resources
  // related to this context.
  std::function<void()> resource_deallocator_ = nullptr;
};

inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
  return down_cast<EagerContext*>(context);
}

namespace internal {
struct EagerContextDeleter {
  void operator()(EagerContext* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

using EagerContextPtr =
    std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
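
// Usage (a sketch; assumes `raw_ctx` is an EagerContext* whose reference the
// caller wants automatically released):
//
//   EagerContextPtr ctx(raw_ctx);
//   // ctx behaves like a std::unique_ptr; Release() (i.e. an Unref) runs
//   // when it goes out of scope.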

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_