• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29 
30 namespace tensorflow {
31 
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33     WorkerEnv* worker_env, const string& default_worker_name,
34     std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35     WorkerCacheFactory worker_cache_factory)
36     : worker_env_(worker_env),
37       default_worker_cache_(std::move(default_worker_cache)),
38       legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39           "", default_worker_name,
40           std::unique_ptr<WorkerCacheInterface>(
41               new WorkerCacheWrapper(default_worker_cache_.get())),
42           worker_env->device_mgr,
43           std::unique_ptr<GraphMgr>(
44               new GraphMgr(worker_env, worker_env->device_mgr)),
45           nullptr)),
46       worker_cache_factory_(std::move(worker_cache_factory)) {}
47 
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50   return strings::StrCat("/job:", server_def.job_name(),
51                          "/replica:0/task:", server_def.task_index());
52 }
53 
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55                                  const ServerDef& server_def,
56                                  bool isolate_session_state) {
57   return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59 
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61     const string& session, const ServerDef& server_def,
62     const protobuf::RepeatedPtrField<DeviceAttributes>&
63         cluster_device_attributes,
64     bool isolate_session_state) {
65   return CreateSession(session, server_def, cluster_device_attributes,
66                        isolate_session_state, /*master_task=*/"",
67                        /*master_incarnation=*/0);
68 }
69 
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state,string master_task,int64_t master_incarnation)70 Status SessionMgr::CreateSession(
71     const string& session, const ServerDef& server_def,
72     const protobuf::RepeatedPtrField<DeviceAttributes>&
73         cluster_device_attributes,
74     bool isolate_session_state, string master_task,
75     int64_t master_incarnation) {
76   mutex_lock l(mu_);
77   if (session.empty()) {
78     return errors::InvalidArgument("Session must be non-empty.");
79   }
80 
81   // For given master task name, check if one or more `WorkerSession`s have been
82   // created previously on this worker, and if so garbage collect the expired
83   // `WorkerSession`s. This happens when the master fails before sending
84   // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
85   if (!master_task.empty()) {
86     auto it_range = master_to_associated_sessions_.equal_range(master_task);
87     if (it_range.first != it_range.second &&
88         it_range.first->second.master_incarnation != master_incarnation) {
89       LOG(INFO) << "When creating WorkerSession for master task " << master_task
90                 << ", found old WorkerSessions created by the same master task "
91                 << "with a different incarnation. These sessions will "
92                 << "be garbage collected. Current WorkerSession count: "
93                 << sessions_.size();
94 
95       auto it = it_range.first;
96       while (it != it_range.second) {
97         auto session_it = sessions_.find(it->second.session_handle);
98         if (session_it != sessions_.end()) {
99           sessions_.erase(session_it);
100         }
101         it = master_to_associated_sessions_.erase(it);
102       }
103     }
104   }
105 
106   WorkerCacheInterface* worker_cache = nullptr;
107   string worker_name;
108   if (server_def.cluster().job().empty()) {
109     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
110     worker_name = legacy_session_->worker_name();
111   } else {
112     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
113     worker_name = WorkerNameFromServerDef(server_def);
114   }
115 
116   if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
117     worker_cache->SetLogging(this->is_logging_active_);
118   }
119 
120   CHECK(!worker_env_->local_devices.empty())
121       << "The WorkerEnv must have at least one device in `local_devices`.";
122 
123   std::shared_ptr<WorkerSession> worker_session;
124   std::vector<std::unique_ptr<Device>> cluster_devices;
125 
126   if (isolate_session_state || server_def.cluster().job_size()) {
127     if (server_def.cluster().job_size()) {
128       VLOG(1) << "ClusterSpec propagation is enabled.";
129     }
130     if (!isolate_session_state) {
131       VLOG(1) << "Session state isolation is disabled.";
132     }
133 
134     // Create a private copy of the DeviceMgr for the WorkerSession.
135     std::vector<std::unique_ptr<Device>> renamed_devices;
136     for (Device* d : worker_env_->local_devices) {
137       renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
138           worker_name, d, false, isolate_session_state));
139     }
140     auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
141     LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
142       return device_mgr->LookupDevice(name, device);
143     };
144     AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
145                     &cluster_devices);
146     std::unique_ptr<DynamicDeviceMgr> remote_devices;
147     if (!cluster_device_attributes.empty()) {
148       remote_devices = MakeUnique<DynamicDeviceMgr>();
149       TF_RETURN_IF_ERROR(
150           remote_devices->AddDevices(std::move(cluster_devices)));
151     }
152 
153     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
154     worker_session.reset(
155         new WorkerSession(session, worker_name,
156                           std::unique_ptr<WorkerCacheInterface>(worker_cache),
157                           std::move(device_mgr), std::move(graph_mgr),
158                           std::move(remote_devices)));
159   } else {
160     AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
161                     &cluster_devices);
162     std::unique_ptr<DynamicDeviceMgr> remote_devices;
163     if (!cluster_device_attributes.empty()) {
164       remote_devices = MakeUnique<DynamicDeviceMgr>();
165       TF_RETURN_IF_ERROR(
166           remote_devices->AddDevices(std::move(cluster_devices)));
167     }
168     // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
169     // that resources using it can use its devices after the
170     // WorkerSession has been deleted.
171     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
172     worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
173         session, worker_name,
174         std::unique_ptr<WorkerCacheInterface>(worker_cache),
175         worker_env_->device_mgr, std::move(graph_mgr),
176         std::move(remote_devices));
177   }
178 
179   sessions_.insert(std::make_pair(session, std::move(worker_session)));
180   if (!master_task.empty()) {
181     MasterAssociatedSession s{master_incarnation, session};
182     master_to_associated_sessions_.emplace(master_task, s);
183   }
184   return Status::OK();
185 }
186 
ResetDefaultWorkerCache(WorkerCacheInterface * worker_cache)187 void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
188   default_worker_cache_.reset(worker_cache);
189 }
190 
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)191 Status SessionMgr::UpdateSession(
192     const string& session, const ServerDef& server_def,
193     const protobuf::RepeatedPtrField<DeviceAttributes>&
194         cluster_device_attributes,
195     bool isolate_session_state) {
196   mutex_lock l(mu_);
197   if (session.empty()) {
198     return errors::InvalidArgument("Session must be non-empty.");
199   }
200   auto it = sessions_.find(session);
201   if (it == sessions_.end()) {
202     return errors::InvalidArgument("Cannot update session ", session,
203                                    " because it does not exist.");
204   }
205   std::shared_ptr<WorkerSession> worker_session = it->second;
206 
207   WorkerCacheInterface* worker_cache = nullptr;
208   if (server_def.cluster().job().empty()) {
209     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
210   } else {
211     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
212   }
213   std::vector<string> updated_remote_workers;
214   worker_cache->ListWorkers(&updated_remote_workers);
215 
216   std::vector<std::unique_ptr<Device>> cluster_devices;
217 
218   const DeviceMgr* local_device_mgr = worker_session->device_mgr();
219   DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
220   std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
221   std::vector<std::unique_ptr<Device>> added_remote_devices;
222   std::vector<Device*> removed_remote_devices;
223 
224   std::vector<DeviceAttributes> added_cluster_device_attrs;
225   for (const auto& da : cluster_device_attributes) {
226     Device* device;
227     if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
228         !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
229       added_cluster_device_attrs.emplace_back(da);
230     } else if (device != nullptr &&
231                device->attributes().incarnation() != da.incarnation()) {
232       removed_remote_devices.emplace_back(device);
233       added_cluster_device_attrs.emplace_back(da);
234     }
235   }
236   for (Device* device : curr_remote_devices) {
237     string task_name;
238     DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
239     if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
240                   task_name) == updated_remote_workers.end()) {
241       removed_remote_devices.emplace_back(device);
242     }
243   }
244   protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
245       added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
246   AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
247                   &added_remote_devices);
248 
249   TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
250       std::unique_ptr<WorkerCacheInterface>(worker_cache),
251       std::move(added_remote_devices), removed_remote_devices));
252   return Status::OK();
253 }
254 
DeleteSession(const string & session)255 Status SessionMgr::DeleteSession(const string& session) {
256   mutex_lock l(mu_);
257   auto it = sessions_.find(session);
258   if (it != sessions_.end()) {
259     sessions_.erase(it);
260   }
261   return Status::OK();
262 }
263 
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)264 Status SessionMgr::WorkerSessionForSessionLocked(
265     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
266   if (session_handle.empty()) {
267     *out_session = legacy_session_;
268   } else {
269     auto it = sessions_.find(session_handle);
270     if (it == sessions_.end()) {
271       return errors::Aborted("Session handle is not found: ", session_handle,
272                              ". Possibly this worker (\"",
273                              legacy_session_->worker_name(),
274                              "\") just restarted.");
275     } else {
276       *out_session = it->second;
277     }
278   }
279   return Status::OK();
280 }
281 
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)282 Status SessionMgr::WorkerSessionForSession(
283     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
284   mutex_lock l(mu_);
285   return WorkerSessionForSessionLocked(session_handle, out_session);
286 }
287 
LegacySession()288 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
289   return legacy_session_;
290 }
291 
SetLogging(bool active)292 void SessionMgr::SetLogging(bool active) {
293   mutex_lock l(mu_);
294   this->is_logging_active_ = active;
295   // Legacy Session
296   if (legacy_session_) {
297     auto* worker_cache = legacy_session_->worker_cache();
298     if (worker_cache) {
299       worker_cache->SetLogging(active);
300     }
301   }
302 
303   for (const auto& session_kv : sessions_) {
304     auto session = session_kv.second.get();
305     if (session) {
306       auto* worker_cache = session->worker_cache();
307       if (worker_cache) {
308         worker_cache->SetLogging(active);
309       }
310     }
311   }
312 }
313 
RetrieveLogs(int64_t step_id,LoggingResponse * response)314 void SessionMgr::RetrieveLogs(int64_t step_id, LoggingResponse* response) {
315   mutex_lock l(mu_);
316   // Legacy Session
317   if (legacy_session_) {
318     auto* worker_cache = legacy_session_->worker_cache();
319     if (worker_cache) {
320       auto step_stats = StepStats();
321       if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
322         auto* labeled_step_stats = response->add_step();
323         labeled_step_stats->set_step_id(step_id);
324         labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
325       }
326     }
327   }
328   for (const auto& session_kv : sessions_) {
329     auto session = session_kv.second.get();
330     if (session) {
331       auto* worker_cache = session->worker_cache();
332       if (worker_cache) {
333         auto step_stats = StepStats();
334         if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
335           auto* labeled_step_stats = response->add_step();
336           labeled_step_stats->set_step_id(step_id);
337           labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
338         }
339       }
340     }
341   }
342 }
343 
ClearLogs()344 void SessionMgr::ClearLogs() {
345   mutex_lock l(mu_);
346   // Legacy Session
347   if (legacy_session_) {
348     auto* worker_cache = legacy_session_->worker_cache();
349     if (worker_cache) {
350       worker_cache->ClearLogs();
351     }
352   }
353 
354   for (const auto& session_kv : sessions_) {
355     auto session = session_kv.second.get();
356     if (session) {
357       auto* worker_cache = session->worker_cache();
358       if (worker_cache) {
359         worker_cache->ClearLogs();
360       }
361     }
362   }
363 }
364 }  // namespace tensorflow
365