1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/c/eager/immediate_execution_context.h"
30 #include "tensorflow/core/common_runtime/composite_device.h"
31 #include "tensorflow/core/common_runtime/device_factory.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/eager/custom_device.h"
34 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
35 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/example/example.pb.h"
41 #include "tensorflow/core/framework/collective.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/log_memory.h"
44 #include "tensorflow/core/framework/rendezvous.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/core/threadpool.h"
49 #include "tensorflow/core/lib/gtl/flatmap.h"
50 #include "tensorflow/core/lib/gtl/flatset.h"
51 #include "tensorflow/core/lib/gtl/inlined_vector.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/fingerprint.h"
56 #include "tensorflow/core/platform/mutex.h"
57 #include "tensorflow/core/platform/platform.h"
58 #include "tensorflow/core/platform/status.h"
59 #include "tensorflow/core/platform/thread_annotations.h"
60 #include "tensorflow/core/platform/threadpool.h"
61 #include "tensorflow/core/public/session_options.h"
62 #include "tensorflow/core/public/version.h"
63 #include "tensorflow/core/util/device_name_utils.h"
64
65 // "tensorflow/core/platform/platform.h" must be included first before using
66 // IS_MOBILE_PLATFORM.
67 #if !defined(IS_MOBILE_PLATFORM)
68 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
69 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
70 #include "tensorflow/core/distributed_runtime/server_lib.h"
71 #include "tensorflow/core/distributed_runtime/worker_cache.h"
72 #include "tensorflow/core/distributed_runtime/worker_env.h"
73 #endif // !IS_MOBILE_PLATFORM
74
75 namespace tensorflow {
76
namespace eager {
// Forward declaration that breaks a circular header dependency:
//   Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove Context dependency in TensorHandle.
class RemoteMgr;
}  // namespace eager

// Forward declarations.
class TensorHandle;
class EagerOperation;
86
87 class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
88 public:
89 static constexpr uint64 kInvalidContextId = 0;
90
NewContextId()91 static uint64 NewContextId() {
92 uint64 context_id = random::New64();
93 while (context_id == kInvalidContextId) {
94 context_id = random::New64();
95 }
96 return context_id;
97 }
98
99 EagerContext(
100 const SessionOptions& opts,
101 ContextDevicePlacementPolicy default_device_placement_policy, bool async,
102 /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned,
103 /*const*/ Rendezvous* rendezvous,
104 DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
105 CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
106 bool run_eager_op_as_function = false, bool jit_compile_rewrite = false);
107
Release()108 void Release() override { Unref(); }
109
110 AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
111 AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
112 AbstractTensorInterface* CreateInt32Scalar(int32_t value) override;
113 AbstractTensorInterface* CreateFloatScalar(float value) override;
114 AbstractTensorInterface* CreateDoubleScalar(double value) override;
115 AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
116 AbstractTensorInterface* CreateStringScalar(
117 tensorflow::tstring value) override;
118 AbstractTensorInterface* CreateComplex128Scalar(
119 tensorflow::complex128 value) override;
120 AbstractTensorInterface* CreateBoolScalar(bool value) override;
121
122 AbstractTensorInterface* CreateTensor(
123 DataType dtype, absl::Span<const int64_t> dim_sizes) override;
124 AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
125 int num_dims, void* data, size_t len,
126 MemoryReleaser memory_releaser,
127 void* memory_releaser_arg) override;
128
129 ImmediateExecutionTensorHandle* CreateLocalHandle(
130 AbstractTensorInterface* t) override;
131 // Create an abstract tensor handle from tensorflow::Tensor.
132 ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
133 tensorflow::Tensor& t, const char* d_name) override;
134 ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
135 ImmediateExecutionTensorHandle* handle, const char* device_name,
136 Status* status) override;
137 ImmediateExecutionOperation* CreateOperation() override;
138
139 // This is a virtual helper function to convert TFRT TensorHandle to
140 // tensorflow::TensorHandle. In current runtime EagerContext, just forward
141 // the input since the input tensor handle is already a
142 // tensorflow::TensorHandle.
143 ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
144 ImmediateExecutionTensorHandle* handle) override;
145
146 Status RegisterFunction(AbstractFunction* f) override;
147
148 bool UsesTFRT() override;
149
150 bool RunEagerOpAsFunction() const;
151
152 void SetRunEagerOpAsFunction(bool enable) override;
153
154 bool JitCompileRewrite() const;
155
156 void SetJitCompileRewrite(bool enable) override;
157
158 void ListDevices(std::vector<DeviceAttributes>* devices) override;
159
160 Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override;
161
GetThreadPool()162 thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); }
163
164 // Returns the function library runtime for the given device.
func_lib(const Device * d)165 FunctionLibraryRuntime* func_lib(const Device* d) const {
166 return pflr_->GetFLR(d->name());
167 }
168
pflr()169 ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
170
runner()171 std::function<void(std::function<void()>)>* runner() { return &runner_; }
172
173 // Specify a executor for this thread.
174 void SetExecutorForThread(EagerExecutor* executor) override;
175
prioritized_device_type_list()176 const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
177 const {
178 mutex_lock l(device_type_list_mu_);
179 return prioritized_device_type_list_;
180 }
181
182 // Clear pending nodes in thread executors and kernel caches.
183 void ClearCachesAndThreadExecutors() override;
184 // Clear pending nodes in default executor and kernel caches.
185 void ClearCachesAndDefaultExecutor();
186
187 // Sets the device placement policy for the current thread.
188 void SetThreadLocalDevicePlacementPolicy(
189 ContextDevicePlacementPolicy policy) override;
190
191 // Returns the device placement policy for the current thread.
192 ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
193
194 // Select an appropriate device for an operation.
195 //
196 // Given the preferred device for the operation, and the node_def, finds the
197 // best suitable device for the operation in this context.
198 //
199 // The preferred device is specified as a `ParsedName` containing the elements
200 // (details) that the resulting device should match. If there are no such
201 // devices, and the context currently allows soft device placement, a suitable
202 // device not matching `preferred` will be chosen.
203 //
204 // The chosen device is stored in the `device` argument. The argument is not
205 // modified unless this method returns `Status::OK()`.
206 Status SelectDevice(DeviceNameUtils::ParsedName preferred,
207 const NodeDef& ndef, Device** out) const;
208
209 // TODO(mdan): Rename to ContainsFunction.
210 bool FindFunctionByName(const string& name) const;
211
212 Status FindFunctionOpData(const string& name,
213 const tensorflow::OpRegistrationData** op_data);
214
215 const FunctionDef* FindFunctionDef(const string& name) const override;
216
HostCPU()217 Device* HostCPU() const { return host_cpu_device_; }
CanonicalDevice(Device * d)218 Device* CanonicalDevice(Device* d) const {
219 return HostCPU() == d ? nullptr : d;
220 }
HostCPUParsedName()221 const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
222 return HostCPU()->parsed_name();
223 }
224
HostCPUName()225 const string& HostCPUName() const override { return HostCPU()->name(); }
226
GetGraphCollector()227 GraphCollector* GetGraphCollector() { return &graph_collector_; }
228
229 EagerExecutor& Executor() override;
230
231 // Add the given `fdef` to the local FunctionLibraryDefinition. And add an
232 // entry to the KernelAndDevice cache for it if it's not exist.
233 Status AddFunctionDef(const FunctionDef& fdef) override;
234
235 Status AddFunctionDefWithStackTraces(
236 const FunctionDef& fdef, const StackTracesMap& stack_traces) override;
237
238 // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
239 // it to the local FunctionLibraryDefinition as well, but no need to add it
240 // to the KernelAndDevice cache since they won't be executed as
241 // KernelAndDevices.
242 Status AddFunctionDef(const FunctionDef& fdef,
243 const FunctionDefLibrary& library,
244 bool add_to_local_only = false,
245 const StackTracesMap& stack_traces = {});
246
247 const FunctionDef* GetFunctionDef(const string& function_name);
248
249 std::vector<string> ListFunctionNames() override;
250
251 Status RemoveFunction(const string& func) override;
252
253 // Wait for pending nodes to be finished in local executors (including context
254 // default executor and thread executors) and executors on remote workers.
255 // Return combined status of remote executors. If there are multiple errors,
256 // the Status code will be the same as the first remote executor that has
257 // errors, and the error message will be combined from all executors.
258 Status SyncExecutors();
259
AsyncWait()260 Status AsyncWait() override { return SyncExecutors(); }
261
262 core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
263 Device* GetCachedDevice(Fprint128 device_cache_key);
264
265 void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
266 void AddDeviceToCache(Fprint128 device_cache_key, Device* device);
267
LogDevicePlacement()268 bool LogDevicePlacement() const { return log_device_placement_; }
SetLogDevicePlacement(bool enable)269 void SetLogDevicePlacement(bool enable) override {
270 log_device_placement_ = enable;
271 }
272
273 // When tensor transfer across functions/eager executions using send/recv ops
274 // are required, `reuse_rendezvous_for_functions_` can be set to true so that
275 // function executions and eager executions use the same rendezvous instance,
276 // instead of creating new instance per function calls.
SetReuseRendezvousForFunctions(bool reuse_rendezvous_for_functions)277 void SetReuseRendezvousForFunctions(
278 bool reuse_rendezvous_for_functions) override {
279 reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
280 }
GetReuseRendezvousForFunctions()281 bool GetReuseRendezvousForFunctions() const {
282 return reuse_rendezvous_for_functions_;
283 }
reuse_rendezvous_for_functions_mu()284 mutex* reuse_rendezvous_for_functions_mu() {
285 return &reuse_rendezvous_for_functions_mu_;
286 }
287
AllowSoftPlacement()288 bool AllowSoftPlacement() const { return allow_soft_placement_; }
SetAllowSoftPlacement(bool enable)289 void SetAllowSoftPlacement(bool enable) override {
290 allow_soft_placement_ = enable;
291 }
LogMemory()292 bool LogMemory() const { return log_memory_; }
293
GetRendezvous()294 Rendezvous* GetRendezvous() const { return rendezvous_; }
295
ResetGlobalRendezvousForFunction()296 void ResetGlobalRendezvousForFunction() override {
297 mutex_lock l(global_rendezvous_mu_);
298 // Remove the global rendezvous instance from the local rendezvous table
299 // if it uses local rendezvous type, which forces EagerContext to create a
300 // new local rendezvous instance in the table.
301 local_rendezvous_table_->Remove(-1);
302 global_rendezvous_for_functions_ =
303 core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
304 }
305
306 // Returns the global_rendezvous_for_functions' underlying LocalRendezvous'
307 // status. If the underlying Rendezvous is not in the local_rendezvous_table_
308 // returns OK.
309 Status GetGlobalRendezvousForFunctionLocalRendezvousStatus();
310
311 // Returns a function which maps from step_id to rendezvous. This closure
312 // respects the value of `SetReuseRendezvousForFunctions` at the time the
313 // closure was created, which allows the setting to be toggled around async op
314 // launches.
315 //
316 // The caller of the returned function owns a reference to the resulting
317 // Rendezvous.
RendezvousCreator()318 std::function<Rendezvous*(int64_t)> RendezvousCreator() {
319 // There is an implicit assumption that the global_rendezvous_for_functions_
320 // is always an IntraProcessRendezvous to match the behaviour of the
321 // EagerContext's rendezvous.
322 // Ref: tensorflow/c/eager/c_api.cc;l=143;rcl=396387348
323 // If a cross process kernel needs a rendezvous a new InterProcessRendezvous
324 // should be created.
325 if (reuse_rendezvous_for_functions_ && rendezvous_creator_ == nullptr &&
326 #if !defined(IS_MOBILE_PLATFORM)
327 worker_env_ == nullptr &&
328 #endif
329 remote_device_mgr() == nullptr) {
330 return [this](int64_t step_id) {
331 mutex_lock l(global_rendezvous_mu_);
332 global_rendezvous_for_functions_->Ref();
333 return global_rendezvous_for_functions_.get();
334 };
335 } else {
336 return [this](int64_t step_id) { return CreateRendezvous(step_id); };
337 }
338 }
339
collective_executor_mgr()340 CollectiveExecutorMgrInterface* collective_executor_mgr() {
341 return collective_executor_mgr_.Get();
342 }
GetCollectiveExecutorHandle()343 std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
344 return std::unique_ptr<CollectiveExecutor::Handle>(
345 new CollectiveExecutor::Handle(
346 collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
347 }
348
local_device_mgr()349 tensorflow::DeviceMgr* local_device_mgr() const {
350 return local_device_manager_.Get();
351 }
remote_device_mgr()352 const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
353 return remote_device_manager_.Get();
354 }
355
GetOwnedRemoteDeviceMgr()356 tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
357 return remote_device_manager_.GetOwned();
358 }
359
ListLocalTfDevices()360 std::vector<Device*> ListLocalTfDevices() override {
361 return local_device_mgr()->ListDevices();
362 }
363
364 std::vector<Device*> ListAllTfDevices() override;
365
366 // TODO(apassos) clean up RunMetadata storage.
MetadataMu()367 mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
368 bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
369 void SetShouldStoreGraphs(bool value) override;
RunMetadataProto()370 RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
371 return run_metadata_.get();
372 }
373 std::unique_ptr<RunMetadata> ExportRunMetadata() override
374 TF_LOCKS_EXCLUDED(metadata_mu_);
375
376 void StartStep() override;
377 void EndStep() override;
378 ScopedStepContainer* StepContainer();
379
FuncLibDef()380 FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
381
382 #if !defined(IS_MOBILE_PLATFORM)
383 // Assign the EagerClient pointer to `client` based on the given device / task
384 // name, and increment the refcount of the client. The reference ownership is
385 // transferred to the caller, and the unref should automatically happen when
386 // destructing the RefCountPtr object at the caller's side.
387 // `client` must not be initialized or holding a reference of another object
388 // before calling this method.
389 Status GetClient(Device* device,
390 core::RefCountPtr<eager::EagerClient>* client);
391 Status GetClient(const DeviceNameUtils::ParsedName& device_name,
392 core::RefCountPtr<eager::EagerClient>* client);
393 Status GetClient(const string& remote_task,
394 core::RefCountPtr<eager::EagerClient>* client);
395
396 uint64 GetContextId() const;
397 uint64 GetContextViewId() const;
398 void IncrementContextViewId();
399
400 Status EnableCollectiveOps(const ServerDef& server_def) override;
401
402 // TODO(nareshmodi): Encapsulate remote state into a separate
403 // class/struct.
404 //
405 // Enables the eager context to communicate with remote devices. When
406 // initializing with this method, this context will be the primary context,
407 // which will kill all its remote contexts in shutdown.
408 //
409 // - server: A ServerInterface that exports the tensorflow.WorkerService.
410 // Note that this class expects the server to already have been started.
411 // - remote_eager_workers: A cache from which we can get "EagerClient"s to
412 // communicate with remote eager services.
413 // - remote_device_mgr: A DeviceMgr* which contains all remote devices
414 // (should contain no local devices).
415 // - remote_contexts: A vector containing task names.
416 // TODO(b/184375824): clean up parameter order for better readability.
417 Status InitializeRemoteMaster(
418 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
419 std::shared_ptr<WorkerSession> worker_session,
420 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
421 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
422 const std::vector<string>& remote_contexts, uint64 context_id,
423 /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr,
424 int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
425 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
426 remote_mgr);
427
428 // Update an existing master context with a new set of remote workers (i.e., a
429 // new "view" of cluster membership. Similar to InitializeRemoteMaster but
430 // this will keep the current context_id and increment a context_view_id, will
431 // keep the current resource manager so that resources from the previous view
432 // can still be accessed, and will automatically register existing functions
433 // if there are newly added hosts.
434 Status UpdateRemoteMaster(
435 uint64 context_id,
436 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
437 const std::vector<string>& add_remote_contexts,
438 const std::vector<string>& remove_remote_contexts);
439
440 // Similar with InitializeRemoteMaster but this context will not kill remote
441 // contexts in shutdown.
442 Status InitializeRemoteWorker(
443 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
444 DynamicDeviceMgr* remote_device_mgr,
445 const std::vector<string>& remote_contexts, uint64 context_id,
446 uint64 context_view_id,
447 std::function<Rendezvous*(const int64_t)> rendezvous_creator,
448 DistributedFunctionLibraryRuntime* cluster_flr,
449 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
450 remote_mgr,
451 std::function<void()> resource_deallocator);
452
453 // Similar with InitializeRemoteWorker but will reuse existing context and
454 // increment context_view_id.
455 Status UpdateRemoteWorker(
456 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
457 const std::vector<string>& remote_contexts, uint64 context_id);
458
459 Status StoreCollectiveOpsServer(
460 std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
461 CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
462
463 // For the specified remote worker, preprocess and set its device filters.
464 Status SetRemoteDeviceFilters(const string& remote_worker,
465 const std::vector<string>& device_filters);
466
467 // For the specified remote worker, apply the stored device filters to the
468 // list of device attributes following these rules:
469 // (1) if the remote worker does not have device filters, all devices are
470 // visible to the worker;
471 // (2) if the device is on the remote worker, then it is visible;
472 // (3) if the device matches at least one device filter, then it is visible.
473 // The result is saved as a boolean vector of the same length (i.e.,
474 // filtered_device_mask) indicating whether each of the devices is visible to
475 // the remote worker.
476 void FilterDevicesForRemoteWorkers(
477 const string& remote_worker,
478 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
479 std::vector<bool>* filtered_device_mask);
480
481 // TODO(fishx): Remove the custom deleter once we remove forward declaration.
482 const std::unique_ptr<eager::RemoteMgr,
483 std::function<void(eager::RemoteMgr*)>>&
RemoteMgr()484 RemoteMgr() {
485 return remote_mgr_;
486 }
487
488 // If true, then tensors should be shipped across processes via the
489 // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
490 // used instead (which in-turn use WorkerService.RecvTensor RPCs).
UseSendTensorRPC()491 bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
492
GetServer()493 tensorflow::ServerInterface* GetServer() { return server_.get(); }
494
495 // For LLVM style RTTI.
classof(const AbstractContext * ptr)496 static bool classof(const AbstractContext* ptr) {
497 return ptr->getKind() == kEager;
498 }
499
500 // Function to support distributed C API.
SetDistributedManager(std::unique_ptr<ImmediateExecutionDistributedManager> distributed)501 void SetDistributedManager(
502 std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
503 override {
504 distributed_manager_ = std::move(distributed);
505 }
GetDistributedManager()506 ImmediateExecutionDistributedManager* GetDistributedManager() override {
507 return distributed_manager_.get();
508 }
509
510 // May only be used during multi-client setup so that a RemoteRendezvous
511 // can be initialized instead of defaulting to the IntraProcessRendezvous.
512 void SetWorkerEnv(WorkerEnv* worker_env,
513 std::shared_ptr<WorkerSession> worker_session);
514 #endif // IS_MOBILE_PLATFORM
515
516 // Closes remote eager contexts, waits for all RPCs to finish, and
517 // destroys the EagerClientCache. No RPCs can be made through this context
518 // after this method has been called.
519 // This method exists to aid a clean shutdown. It causes all RPCs to finish
520 // and remote TensorHandles to release their references to this context.
521 // To avoid deadlocks, this method must not be called on the thread
522 // processing RPCs because it makes RPCs and waits for their completion.
523 //
524 // On mobile, it just cleans the caches.
525 void WaitForAndCloseRemoteContexts();
526
PinSmallOpsToCPU()527 bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }
528
TFEnv()529 tensorflow::Env* TFEnv() const { return env_; }
530
531 Status FindDeviceFromName(const char* device_name, Device** device) const;
532
533 Status FindCompositeDeviceFromName(StringPiece device_name,
534 CompositeDevice** device) const;
535
536 Status RegisterCustomDevice(const string& name,
537 std::unique_ptr<CustomDevice> device) override;
538
GetCustomDeviceOpHandler()539 CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
540 return custom_device_op_handler_;
541 };
542
543 // Find or create a composite device with the given `underlying_devices` and
544 // `device_name` (if not empty).
545 Status FindOrCreateCompositeDevice(
546 const std::vector<string>& underlying_devices, const string& device_name,
547 CompositeDevice** composite_device);
548
549 bool OnSameTask(const Device* first, const Device* second) const;
550 // Gets the CPU device on the task of device.
551 Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
552
session_options()553 const SessionOptions& session_options() const { return opts_; }
554 void InitPrioritizedDeviceTypeList();
555
556 // Re-assign cluster-FLR and re-initialize devices and FLR in process-FLR
557 void UpdateClusterFLRAndInitDevices(
558 DistributedFunctionLibraryRuntime* cluster_flr);
559
560 // A constant representing the step id used for the global rendezvous.
561 // This is used to distibguish whether a user-specified step id should be set.
562 // Step id value of kGlobalRendezvous is reserved and should not be specified
563 // by the user.
564 static const int64_t kGlobalRendezvousId;
565
566 private:
567 // The class for wrapping a map of step_id to local rendezvous instances.
568 class LocalRendezvousTable {
569 public:
570 LocalRendezvousTable() = default;
571 ~LocalRendezvousTable();
572
573 IntraProcessRendezvous* FindOrCreate(int64_t step_id,
574 DeviceMgr* device_mgr);
575 IntraProcessRendezvous* Find(int64_t step_id);
576 void Remove(int64_t step_id);
577 void CleanUpAll();
578
579 private:
580 mutable mutex table_lock_;
581 absl::flat_hash_map<int64_t, IntraProcessRendezvous*> table_
582 TF_GUARDED_BY(table_lock_);
583 };
584
CreateRendezvous(int64_t step_id)585 Rendezvous* CreateRendezvous(int64_t step_id) const {
586 if (rendezvous_creator_ != nullptr) {
587 VLOG(6) << "Creating rendezvous using the rendezvous_creator_.";
588 return rendezvous_creator_(step_id);
589 }
590
591 #if !defined(IS_MOBILE_PLATFORM)
592 if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
593 VLOG(6) << "Creating rendezvous using the worker_env's rendezvous_mgr.";
594 auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
595 remote_r->Initialize(worker_session_.get()).IgnoreError();
596 return remote_r;
597 }
598 #endif
599
600 if (remote_device_mgr() == nullptr) {
601 VLOG(6) << "Creating rendezvous using local_device_mgr.";
602 return local_rendezvous_table_->FindOrCreate(step_id, local_device_mgr());
603 }
604
605 return nullptr;
606 }
607
608 ~EagerContext() override;
609
610 Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
611 Status RegisterExistingFunctionsOnRemoteWorkers(
612 const std::vector<string>& remote_workers);
613
614 void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
615 const ConfigProto* config, int graph_def_version,
616 const FunctionLibraryDefinition* lib_def,
617 const OptimizerOptions& optimizer_options,
618 thread::ThreadPool* thread_pool = nullptr,
619 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
620
621 void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
622 void UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr* device_mgr);
623
624 void ClearResourceContainer(const string& name);
625
// Holds a T* that is either owned (through a unique_ptr) or merely borrowed.
// At most one of `owned_object` / `unowned_object_ptr` is non-null at a time.
template <typename T>
struct OwnedOrUnownedHelper {
  OwnedOrUnownedHelper() = default;
  explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
    Reset(object, owned);
  }

  // Takes ownership of `object` and forgets any borrowed pointer.
  void Reset(std::unique_ptr<T> object) {
    owned_object = std::move(object);
    unowned_object_ptr = nullptr;
  }

  // Stores `object`, taking ownership iff `owned` is true. Any previously
  // owned object is deleted.
  void Reset(T* object, const bool owned = false) {
    owned_object.reset(owned ? object : nullptr);
    unowned_object_ptr = owned ? nullptr : object;
  }

  bool Owned() const { return owned_object != nullptr; }

  // Returns the owned object, or nullptr when the pointer is only borrowed.
  T* GetOwned() const { return owned_object.get(); }
  // Returns whichever pointer is currently held (owned takes precedence).
  T* Get() const {
    if (owned_object != nullptr) {
      return owned_object.get();
    }
    return unowned_object_ptr;
  }

  std::unique_ptr<T> owned_object;
  T* unowned_object_ptr = nullptr;
};
659
660 SessionOptions opts_;
661 const ContextDevicePlacementPolicy default_device_placement_policy_;
662
663 // Note: we cannot use C++11 thread_local here as there is no concept of a
664 // thread-local-object-local variable in C++11.
665 mutable mutex policy_map_mu_;
666 std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
667 device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);
668
669 // This device manager maintains only the local devices on this worker.
670 OwnedOrUnownedHelper<DeviceMgr> local_device_manager_;
671 // Maintain copy of all previously created local device managers.
672 std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_;
673
674 // Unowned DynamicDeviceMgr is set on remote worker to allow running
675 // multi-device function on remote worker.
676 // This device manager maintains all the devices (including both local and
677 // remote to this worker) in the cluster.
678 OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
679
680 Device* host_cpu_device_; // Owned by device_manager
681 mutable mutex device_type_list_mu_;
682 std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
683 TF_GUARDED_BY(device_type_list_mu_);
684 Rendezvous* rendezvous_;
685 std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
686 CustomDeviceOpHandler custom_device_op_handler_;
687
688 mutable mutex composite_devices_mu_;
689 // Maps from the fingerprint of a set of device names to a virtual
690 // CompositeDevice.
691 // TODO(b/145922293): Consider taking device names as keys.
692 absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
693 composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);
694
Global()695 FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
696
697 std::unique_ptr<thread::ThreadPool> thread_pool_;
698
699 // EagerContext owns the DistributedFunctionLibraryRuntime(
700 // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
701 // function execution (lazy_copy_function_remote_inputs_=true).
702 OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
703 // One FunctionLibraryRuntime per device.
704 // func_libs[i] is the FunctionLibraryRuntime corresponding to
705 // session->devices[i].
706 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
707
708 std::function<void(std::function<void()>)> runner_;
709
710 mutex cache_mu_;
711 mutex device_cache_mu_;
712 struct RegisteredFunction : public core::RefCounted {
~RegisteredFunctionRegisteredFunction713 ~RegisteredFunction() override {}
714
715 std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
716 };
717 std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
718 Fprint128Hasher>
719 kernel_cache_ TF_GUARDED_BY(cache_mu_);
720 std::unordered_map<string, RegisteredFunction*> registered_functions_
721 TF_GUARDED_BY(cache_mu_);
722 absl::flat_hash_map<Fprint128, Device*, Fprint128Hasher> device_cache_
723 TF_GUARDED_BY(device_cache_mu_);
724
725 // Whether we should compute RunMetadata.
726 std::atomic<bool> should_store_graphs_{false};
727 mutex metadata_mu_;
728 std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
729 GraphCollector graph_collector_;
730 std::atomic<bool> log_device_placement_;
731 std::atomic<bool> allow_soft_placement_;
732
733 // Information related to step containers.
734 std::atomic<int> num_active_steps_;
735 std::unique_ptr<ScopedStepContainer> step_container_
736 TF_GUARDED_BY(metadata_mu_);
737
738 EagerExecutor default_executor_;
739 mutable mutex executor_map_mu_;
740 // Not owned.
741 std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
742 TF_GUARDED_BY(executor_map_mu_);
743 std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
744 has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
745
746 const bool log_memory_;
747
748 // The table of local rendezvous instances for intra-process communication.
// This makes sure only one local rendezvous instance exists per step id.
750 std::unique_ptr<LocalRendezvousTable> local_rendezvous_table_;
751
// Whether to use the same rendezvous instance across function/eager
// executions.
753 std::atomic<bool> reuse_rendezvous_for_functions_{false};
754 mutable mutex global_rendezvous_mu_;
755 core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_
756 TF_GUARDED_BY(global_rendezvous_mu_);
757 mutex reuse_rendezvous_for_functions_mu_;
758
759 Env* const env_;
760
761 OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;
762
763 #if !defined(IS_MOBILE_PLATFORM)
764 std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
765 bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
766 void CloseAndClearAllRemoteContexts();
767 void CloseRemoteContexts(const std::vector<string>& remote_contexts,
768 uint64 context_id, uint64 context_view_id);
769
770 // TODO(b/184375824): clean up parameter order for better readability.
771 Status SetMasterContextState(
772 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
773 std::shared_ptr<WorkerSession> worker_session,
774 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
775 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
776 uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r,
777 /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs,
778 DistributedFunctionLibraryRuntime* cluster_flr,
779 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
780 remote_mgr);
781
// server_ is not marked const (even though conceptually it should be)
// because it is released when the context is destroyed.
785 std::unique_ptr<ServerInterface> server_;
786 WorkerEnv* worker_env_ = nullptr;
787 std::shared_ptr<WorkerSession> worker_session_;
788
789 mutable mutex remote_state_mu_;
790
791 uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
792 // The view id of an eager context should be set to 0 when context is created,
793 // and continuously incremented when context with the same context_id gets
794 // updated. The view id should be consistent between master and workers.
795 uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
796 std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
797 std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
798 TF_GUARDED_BY(remote_state_mu_);
799
800 int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
801 std::atomic<int> sleep_for_secs_;
802
803 std::unique_ptr<Thread> keep_alive_thread_;
804 mutex keep_alive_thread_shutdown_mu_;
805 condition_variable keep_alive_thread_cv_;
806 bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
807
808 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
809 remote_mgr_;
810 bool is_master_ TF_GUARDED_BY(remote_state_mu_);
811
812 // Maps from a remote worker to a list of parsed device filters.
813 std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
814 cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);
815
816 // A distributed manager that helps setup, update, and check liveness of
817 // member tasks in the cluster.
818 std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;
819
820 #endif // IS_MOBILE_PLATFORM
821
822 // For a multi device function, the target device of each input is unknown
823 // until the function is instantiated on the default function device.
824 // If false, eagerly copy all remote inputs to the default function device;
825 // if true, lazily copy remote inputs to their target devices to avoid
826 // redundant copies.
827 bool lazy_copy_function_remote_inputs_ = false;
828 bool use_send_tensor_rpc_;
829 const bool pin_small_ops_to_cpu_;
830
831 // Function that will be invoked in destructor to deallocate resources related
832 // to this context.
833 std::function<void()> resource_deallocator_ = nullptr;
834 bool run_eager_op_as_function_;
835 bool jit_compile_rewrite_;
836 };
837
ContextFromInterface(ImmediateExecutionContext * context)838 inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
839 return down_cast<EagerContext*>(context);
840 }
841
842 namespace internal {
843 struct EagerContextDeleter {
operatorEagerContextDeleter844 void operator()(EagerContext* p) const {
845 if (p != nullptr) {
846 p->Release();
847 }
848 }
849 };
850 } // namespace internal
851
852 using EagerContextPtr =
853 std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
854
// Sets the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in python/eager/pywrap_tfe.h). This is always called
// in tandem with TFE_Py_SetEagerContext (but not called by it, because its
// py_context argument is opaque).
//
// Do not use this function in production; it is only intended for testing
// (see _reset_context in context.py).
//
// Not thread-safe.
//
// Args:
//   ctx: the EagerContext to associate with the current Python eager Context.
//     NOTE(review): ownership / reference-count handling of `ctx` is not
//     visible from this declaration; confirm against the definition.
void SetCEagerContext(EagerContext* ctx);

// Returns the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in pywrap_tfe.h). This is presumably the value most
// recently passed to SetCEagerContext; verify against the definition.
// NOTE(review): whether this may return nullptr (e.g. before any context has
// been set) is not visible here. Confirm before relying on a non-null result.
//
// Not thread-safe.
EagerContext* GetCEagerContext();
871
872 } // namespace tensorflow
873
874 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
875