1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16
17 namespace tensorflow {
18
19 namespace {
20
21 // A private cache that wraps worker_cache and allows reuse of
22 // WorkerInterface objects.
23 class WorkerFreeListCache : public WorkerCacheInterface {
24 public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)25 explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
26 : wrapped_(std::move(w)) {}
27
~WorkerFreeListCache()28 ~WorkerFreeListCache() final {
29 for (auto& p : workers_) {
30 wrapped_->ReleaseWorker(p.first, p.second.worker);
31 }
32 }
33
ListWorkers(std::vector<string> * workers) const34 void ListWorkers(std::vector<string>* workers) const override {
35 wrapped_->ListWorkers(workers);
36 }
37
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const38 void ListWorkersInJob(const string& job_name,
39 std::vector<string>* workers) const override {
40 wrapped_->ListWorkersInJob(job_name, workers);
41 }
42
CreateWorker(const string & target)43 WorkerInterface* CreateWorker(const string& target) override {
44 mutex_lock l(mu_);
45 auto p = workers_.find(target);
46 if (p != workers_.end()) {
47 return p->second.worker;
48 }
49 WorkerState state;
50 state.worker = wrapped_->CreateWorker(target);
51 if (state.worker != nullptr) {
52 workers_.insert(std::make_pair(target, state));
53 }
54 return state.worker;
55 }
56
ReleaseWorker(const string & target,WorkerInterface * worker)57 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
58 // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
59 }
60
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)61 bool GetDeviceLocalityNonBlocking(const string& device,
62 DeviceLocality* locality) override {
63 return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
64 }
65
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)66 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
67 StatusCallback done) override {
68 wrapped_->GetDeviceLocalityAsync(device, locality, done);
69 }
70
SetLogging(bool active)71 void SetLogging(bool active) override { wrapped_->SetLogging(active); }
72
ClearLogs()73 void ClearLogs() override { wrapped_->ClearLogs(); }
74
RetrieveLogs(int64 step_id,StepStats * ss)75 bool RetrieveLogs(int64 step_id, StepStats* ss) override {
76 return wrapped_->RetrieveLogs(step_id, ss);
77 }
78
79 private:
80 std::unique_ptr<WorkerCacheInterface> wrapped_;
81
82 // Information kept per created WorkerInterface.
83 struct WorkerState {
84 WorkerInterface* worker;
85 // TODO(jeff,sanjay): Add reference count if we support eviction.
86 };
87
88 // TODO(jeff,sanjay): Eviction when the map becomes too big.
89 mutex mu_;
90 std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
91 };
92
93 } // namespace
94
// Constructs a WorkerSession that OWNS its DeviceMgr (contrast with
// CreateWithBorrowedDeviceMgr, which borrows one).
WorkerSession::WorkerSession(const string& session_name,
                             const string& worker_name,
                             std::unique_ptr<WorkerCacheInterface> worker_cache,
                             std::unique_ptr<DeviceMgr> device_mgr,
                             std::unique_ptr<GraphMgr> graph_mgr)
    : session_name(session_name),
      worker_name(worker_name),
      // Wrap the supplied cache so WorkerInterface objects are reused across
      // repeated lookups of the same target.
      worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr(std::move(graph_mgr)),
      cluster_flr(
          new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
      device_mgr_(std::move(device_mgr)),
      borrowed_device_mgr_(nullptr) {}
108
109 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr)110 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
111 const string& session_name, const string& worker_name,
112 std::unique_ptr<WorkerCacheInterface> worker_cache,
113 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr) {
114 return std::shared_ptr<WorkerSession>(
115 new WorkerSession(session_name, worker_name, std::move(worker_cache),
116 borrowed_device_mgr, std::move(graph_mgr)));
117 }
118
// Constructs a WorkerSession whose DeviceMgr is BORROWED: `device_mgr_` stays
// null and `borrowed_device_mgr_` points at a DeviceMgr owned by the caller.
WorkerSession::WorkerSession(const string& session_name,
                             const string& worker_name,
                             std::unique_ptr<WorkerCacheInterface> worker_cache,
                             DeviceMgr* borrowed_device_mgr,
                             std::unique_ptr<GraphMgr> graph_mgr)
    : session_name(session_name),
      worker_name(worker_name),
      // Wrap the supplied cache so WorkerInterface objects are reused across
      // repeated lookups of the same target.
      worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr(std::move(graph_mgr)),
      cluster_flr(
          new ClusterFunctionLibraryRuntime(this, !session_name.empty())),
      device_mgr_(nullptr),
      borrowed_device_mgr_(borrowed_device_mgr) {}
132
~WorkerSession()133 WorkerSession::~WorkerSession() {
134 if (graph_mgr) {
135 Status s = graph_mgr->DeregisterAll();
136 if (!s.ok()) {
137 LOG(WARNING) << "Error during worker session deletion: " << s;
138 }
139 }
140 }
141
142 } // namespace tensorflow
143