1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29
30 namespace tensorflow {
31
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33 WorkerEnv* worker_env, const string& default_worker_name,
34 std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35 WorkerCacheFactory worker_cache_factory)
36 : worker_env_(worker_env),
37 default_worker_cache_(std::move(default_worker_cache)),
38 legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39 "", default_worker_name,
40 std::unique_ptr<WorkerCacheInterface>(
41 new WorkerCacheWrapper(default_worker_cache_.get())),
42 worker_env->device_mgr,
43 std::unique_ptr<GraphMgr>(
44 new GraphMgr(worker_env, worker_env->device_mgr)),
45 nullptr)),
46 worker_cache_factory_(std::move(worker_cache_factory)) {}
47
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50 return strings::StrCat("/job:", server_def.job_name(),
51 "/replica:0/task:", server_def.task_index());
52 }
53
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55 const ServerDef& server_def,
56 bool isolate_session_state) {
57 return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61 const string& session, const ServerDef& server_def,
62 const protobuf::RepeatedPtrField<DeviceAttributes>&
63 cluster_device_attributes,
64 bool isolate_session_state) {
65 return CreateSession(session, server_def, cluster_device_attributes,
66 isolate_session_state, /*master_task=*/"",
67 /*master_incarnation=*/0);
68 }
69
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state,string master_task,int64_t master_incarnation)70 Status SessionMgr::CreateSession(
71 const string& session, const ServerDef& server_def,
72 const protobuf::RepeatedPtrField<DeviceAttributes>&
73 cluster_device_attributes,
74 bool isolate_session_state, string master_task,
75 int64_t master_incarnation) {
76 mutex_lock l(mu_);
77 if (session.empty()) {
78 return errors::InvalidArgument("Session must be non-empty.");
79 }
80
81 // For given master task name, check if one or more `WorkerSession`s have been
82 // created previously on this worker, and if so garbage collect the expired
83 // `WorkerSession`s. This happens when the master fails before sending
84 // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
85 if (!master_task.empty()) {
86 auto it_range = master_to_associated_sessions_.equal_range(master_task);
87 if (it_range.first != it_range.second &&
88 it_range.first->second.master_incarnation != master_incarnation) {
89 LOG(INFO) << "When creating WorkerSession for master task " << master_task
90 << ", found old WorkerSessions created by the same master task "
91 << "with a different incarnation. These sessions will "
92 << "be garbage collected. Current WorkerSession count: "
93 << sessions_.size();
94
95 auto it = it_range.first;
96 while (it != it_range.second) {
97 auto session_it = sessions_.find(it->second.session_handle);
98 if (session_it != sessions_.end()) {
99 sessions_.erase(session_it);
100 }
101 it = master_to_associated_sessions_.erase(it);
102 }
103 }
104 }
105
106 WorkerCacheInterface* worker_cache = nullptr;
107 string worker_name;
108 if (server_def.cluster().job().empty()) {
109 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
110 worker_name = legacy_session_->worker_name();
111 } else {
112 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
113 worker_name = WorkerNameFromServerDef(server_def);
114 }
115
116 if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
117 worker_cache->SetLogging(this->is_logging_active_);
118 }
119
120 CHECK(!worker_env_->local_devices.empty())
121 << "The WorkerEnv must have at least one device in `local_devices`.";
122
123 std::shared_ptr<WorkerSession> worker_session;
124 std::vector<std::unique_ptr<Device>> cluster_devices;
125
126 if (isolate_session_state || server_def.cluster().job_size()) {
127 if (server_def.cluster().job_size()) {
128 VLOG(1) << "ClusterSpec propagation is enabled.";
129 }
130 if (!isolate_session_state) {
131 VLOG(1) << "Session state isolation is disabled.";
132 }
133
134 // Create a private copy of the DeviceMgr for the WorkerSession.
135 std::vector<std::unique_ptr<Device>> renamed_devices;
136 for (Device* d : worker_env_->local_devices) {
137 renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
138 worker_name, d, false, isolate_session_state));
139 }
140 auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
141 LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
142 return device_mgr->LookupDevice(name, device);
143 };
144 AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
145 &cluster_devices);
146 std::unique_ptr<DynamicDeviceMgr> remote_devices;
147 if (!cluster_device_attributes.empty()) {
148 remote_devices = MakeUnique<DynamicDeviceMgr>();
149 TF_RETURN_IF_ERROR(
150 remote_devices->AddDevices(std::move(cluster_devices)));
151 }
152
153 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
154 worker_session.reset(
155 new WorkerSession(session, worker_name,
156 std::unique_ptr<WorkerCacheInterface>(worker_cache),
157 std::move(device_mgr), std::move(graph_mgr),
158 std::move(remote_devices)));
159 } else {
160 AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
161 &cluster_devices);
162 std::unique_ptr<DynamicDeviceMgr> remote_devices;
163 if (!cluster_device_attributes.empty()) {
164 remote_devices = MakeUnique<DynamicDeviceMgr>();
165 TF_RETURN_IF_ERROR(
166 remote_devices->AddDevices(std::move(cluster_devices)));
167 }
168 // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
169 // that resources using it can use its devices after the
170 // WorkerSession has been deleted.
171 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
172 worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
173 session, worker_name,
174 std::unique_ptr<WorkerCacheInterface>(worker_cache),
175 worker_env_->device_mgr, std::move(graph_mgr),
176 std::move(remote_devices));
177 }
178
179 sessions_.insert(std::make_pair(session, std::move(worker_session)));
180 if (!master_task.empty()) {
181 MasterAssociatedSession s{master_incarnation, session};
182 master_to_associated_sessions_.emplace(master_task, s);
183 }
184 return Status::OK();
185 }
186
ResetDefaultWorkerCache(WorkerCacheInterface * worker_cache)187 void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
188 default_worker_cache_.reset(worker_cache);
189 }
190
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)191 Status SessionMgr::UpdateSession(
192 const string& session, const ServerDef& server_def,
193 const protobuf::RepeatedPtrField<DeviceAttributes>&
194 cluster_device_attributes,
195 bool isolate_session_state) {
196 mutex_lock l(mu_);
197 if (session.empty()) {
198 return errors::InvalidArgument("Session must be non-empty.");
199 }
200 auto it = sessions_.find(session);
201 if (it == sessions_.end()) {
202 return errors::InvalidArgument("Cannot update session ", session,
203 " because it does not exist.");
204 }
205 std::shared_ptr<WorkerSession> worker_session = it->second;
206
207 WorkerCacheInterface* worker_cache = nullptr;
208 if (server_def.cluster().job().empty()) {
209 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
210 } else {
211 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
212 }
213 std::vector<string> updated_remote_workers;
214 worker_cache->ListWorkers(&updated_remote_workers);
215
216 std::vector<std::unique_ptr<Device>> cluster_devices;
217
218 const DeviceMgr* local_device_mgr = worker_session->device_mgr();
219 DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
220 std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
221 std::vector<std::unique_ptr<Device>> added_remote_devices;
222 std::vector<Device*> removed_remote_devices;
223
224 std::vector<DeviceAttributes> added_cluster_device_attrs;
225 for (const auto& da : cluster_device_attributes) {
226 Device* device;
227 if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
228 !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
229 added_cluster_device_attrs.emplace_back(da);
230 } else if (device != nullptr &&
231 device->attributes().incarnation() != da.incarnation()) {
232 removed_remote_devices.emplace_back(device);
233 added_cluster_device_attrs.emplace_back(da);
234 }
235 }
236 for (Device* device : curr_remote_devices) {
237 string task_name;
238 DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
239 if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
240 task_name) == updated_remote_workers.end()) {
241 removed_remote_devices.emplace_back(device);
242 }
243 }
244 protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
245 added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
246 AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
247 &added_remote_devices);
248
249 TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
250 std::unique_ptr<WorkerCacheInterface>(worker_cache),
251 std::move(added_remote_devices), removed_remote_devices));
252 return Status::OK();
253 }
254
DeleteSession(const string & session)255 Status SessionMgr::DeleteSession(const string& session) {
256 mutex_lock l(mu_);
257 auto it = sessions_.find(session);
258 if (it != sessions_.end()) {
259 sessions_.erase(it);
260 }
261 return Status::OK();
262 }
263
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)264 Status SessionMgr::WorkerSessionForSessionLocked(
265 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
266 if (session_handle.empty()) {
267 *out_session = legacy_session_;
268 } else {
269 auto it = sessions_.find(session_handle);
270 if (it == sessions_.end()) {
271 return errors::Aborted("Session handle is not found: ", session_handle,
272 ". Possibly this worker (\"",
273 legacy_session_->worker_name(),
274 "\") just restarted.");
275 } else {
276 *out_session = it->second;
277 }
278 }
279 return Status::OK();
280 }
281
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)282 Status SessionMgr::WorkerSessionForSession(
283 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
284 mutex_lock l(mu_);
285 return WorkerSessionForSessionLocked(session_handle, out_session);
286 }
287
LegacySession()288 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
289 return legacy_session_;
290 }
291
SetLogging(bool active)292 void SessionMgr::SetLogging(bool active) {
293 mutex_lock l(mu_);
294 this->is_logging_active_ = active;
295 // Legacy Session
296 if (legacy_session_) {
297 auto* worker_cache = legacy_session_->worker_cache();
298 if (worker_cache) {
299 worker_cache->SetLogging(active);
300 }
301 }
302
303 for (const auto& session_kv : sessions_) {
304 auto session = session_kv.second.get();
305 if (session) {
306 auto* worker_cache = session->worker_cache();
307 if (worker_cache) {
308 worker_cache->SetLogging(active);
309 }
310 }
311 }
312 }
313
RetrieveLogs(int64_t step_id,LoggingResponse * response)314 void SessionMgr::RetrieveLogs(int64_t step_id, LoggingResponse* response) {
315 mutex_lock l(mu_);
316 // Legacy Session
317 if (legacy_session_) {
318 auto* worker_cache = legacy_session_->worker_cache();
319 if (worker_cache) {
320 auto step_stats = StepStats();
321 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
322 auto* labeled_step_stats = response->add_step();
323 labeled_step_stats->set_step_id(step_id);
324 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
325 }
326 }
327 }
328 for (const auto& session_kv : sessions_) {
329 auto session = session_kv.second.get();
330 if (session) {
331 auto* worker_cache = session->worker_cache();
332 if (worker_cache) {
333 auto step_stats = StepStats();
334 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
335 auto* labeled_step_stats = response->add_step();
336 labeled_step_stats->set_step_id(step_id);
337 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
338 }
339 }
340 }
341 }
342 }
343
ClearLogs()344 void SessionMgr::ClearLogs() {
345 mutex_lock l(mu_);
346 // Legacy Session
347 if (legacy_session_) {
348 auto* worker_cache = legacy_session_->worker_cache();
349 if (worker_cache) {
350 worker_cache->ClearLogs();
351 }
352 }
353
354 for (const auto& session_kv : sessions_) {
355 auto session = session_kv.second.get();
356 if (session) {
357 auto* worker_cache = session->worker_cache();
358 if (worker_cache) {
359 worker_cache->ClearLogs();
360 }
361 }
362 }
363 }
364 } // namespace tensorflow
365