• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29 
30 namespace tensorflow {
31 
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33     WorkerEnv* worker_env, const string& default_worker_name,
34     std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35     WorkerCacheFactory worker_cache_factory)
36     : worker_env_(worker_env),
37       default_worker_cache_(std::move(default_worker_cache)),
38       legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39           "", default_worker_name,
40           std::unique_ptr<WorkerCacheInterface>(
41               new WorkerCacheWrapper(default_worker_cache_.get())),
42           worker_env->device_mgr,
43           std::unique_ptr<GraphMgr>(
44               new GraphMgr(worker_env, worker_env->device_mgr)),
45           nullptr)),
46       worker_cache_factory_(std::move(worker_cache_factory)) {}
47 
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50   return strings::StrCat("/job:", server_def.job_name(),
51                          "/replica:0/task:", server_def.task_index());
52 }
53 
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55                                  const ServerDef& server_def,
56                                  bool isolate_session_state) {
57   return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59 
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61     const string& session, const ServerDef& server_def,
62     const protobuf::RepeatedPtrField<DeviceAttributes>&
63         cluster_device_attributes,
64     bool isolate_session_state) {
65   mutex_lock l(mu_);
66   if (session.empty()) {
67     return errors::InvalidArgument("Session must be non-empty.");
68   }
69 
70   WorkerCacheInterface* worker_cache = nullptr;
71   string worker_name;
72   if (server_def.cluster().job().empty()) {
73     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
74     worker_name = legacy_session_->worker_name();
75   } else {
76     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
77     worker_name = WorkerNameFromServerDef(server_def);
78   }
79 
80   if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
81     worker_cache->SetLogging(this->is_logging_active_);
82   }
83 
84   CHECK(!worker_env_->local_devices.empty())
85       << "The WorkerEnv must have at least one device in `local_devices`.";
86 
87   std::shared_ptr<WorkerSession> worker_session;
88   std::vector<std::unique_ptr<Device>> cluster_devices;
89 
90   if (isolate_session_state || server_def.cluster().job_size()) {
91     if (server_def.cluster().job_size()) {
92       VLOG(1) << "ClusterSpec propagation is enabled.";
93     }
94     if (!isolate_session_state) {
95       VLOG(1) << "Session state isolation is disabled.";
96     }
97 
98     // Create a private copy of the DeviceMgr for the WorkerSession.
99     std::vector<std::unique_ptr<Device>> renamed_devices;
100     for (Device* d : worker_env_->local_devices) {
101       renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
102           worker_name, d, false, isolate_session_state));
103     }
104     auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
105     LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
106       return device_mgr->LookupDevice(name, device);
107     };
108     AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
109                     &cluster_devices);
110     std::unique_ptr<DynamicDeviceMgr> remote_devices;
111     if (!cluster_device_attributes.empty()) {
112       remote_devices = MakeUnique<DynamicDeviceMgr>();
113       TF_RETURN_IF_ERROR(
114           remote_devices->AddDevices(std::move(cluster_devices)));
115     }
116 
117     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
118     worker_session.reset(
119         new WorkerSession(session, worker_name,
120                           std::unique_ptr<WorkerCacheInterface>(worker_cache),
121                           std::move(device_mgr), std::move(graph_mgr),
122                           std::move(remote_devices)));
123   } else {
124     AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
125                     &cluster_devices);
126     std::unique_ptr<DynamicDeviceMgr> remote_devices;
127     if (!cluster_device_attributes.empty()) {
128       remote_devices = MakeUnique<DynamicDeviceMgr>();
129       TF_RETURN_IF_ERROR(
130           remote_devices->AddDevices(std::move(cluster_devices)));
131     }
132     // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
133     // that resources using it can use its devices after the
134     // WorkerSession has been deleted.
135     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
136     worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
137         session, worker_name,
138         std::unique_ptr<WorkerCacheInterface>(worker_cache),
139         worker_env_->device_mgr, std::move(graph_mgr),
140         std::move(remote_devices));
141   }
142 
143   sessions_.insert(std::make_pair(session, std::move(worker_session)));
144   return Status::OK();
145 }
146 
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)147 Status SessionMgr::UpdateSession(
148     const string& session, const ServerDef& server_def,
149     const protobuf::RepeatedPtrField<DeviceAttributes>&
150         cluster_device_attributes,
151     bool isolate_session_state) {
152   mutex_lock l(mu_);
153   if (session.empty()) {
154     return errors::InvalidArgument("Session must be non-empty.");
155   }
156   auto it = sessions_.find(session);
157   if (it == sessions_.end()) {
158     return errors::InvalidArgument("Cannot update session ", session,
159                                    " because it does not exist.");
160   }
161   std::shared_ptr<WorkerSession> worker_session = it->second;
162 
163   WorkerCacheInterface* worker_cache = nullptr;
164   if (server_def.cluster().job().empty()) {
165     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
166   } else {
167     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
168   }
169   std::vector<string> updated_remote_workers;
170   worker_cache->ListWorkers(&updated_remote_workers);
171 
172   std::vector<std::unique_ptr<Device>> cluster_devices;
173 
174   DeviceMgr* local_device_mgr = worker_session->device_mgr();
175   DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
176   std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
177   std::vector<std::unique_ptr<Device>> added_remote_devices;
178   std::vector<Device*> removed_remote_devices;
179 
180   std::vector<DeviceAttributes> added_cluster_device_attrs;
181   for (const auto& da : cluster_device_attributes) {
182     Device* device;
183     if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
184         !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
185       added_cluster_device_attrs.emplace_back(da);
186     } else if (device != nullptr &&
187                device->attributes().incarnation() != da.incarnation()) {
188       removed_remote_devices.emplace_back(device);
189       added_cluster_device_attrs.emplace_back(da);
190     }
191   }
192   for (Device* device : curr_remote_devices) {
193     string task_name;
194     DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
195     if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
196                   task_name) == updated_remote_workers.end()) {
197       removed_remote_devices.emplace_back(device);
198     }
199   }
200   protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
201       added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
202   AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
203                   &added_remote_devices);
204 
205   TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
206       std::unique_ptr<WorkerCacheInterface>(worker_cache),
207       std::move(added_remote_devices), removed_remote_devices));
208   return Status::OK();
209 }
210 
DeleteSession(const string & session)211 Status SessionMgr::DeleteSession(const string& session) {
212   mutex_lock l(mu_);
213   auto it = sessions_.find(session);
214   if (it != sessions_.end()) {
215     sessions_.erase(it);
216   }
217   return Status::OK();
218 }
219 
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)220 Status SessionMgr::WorkerSessionForSessionLocked(
221     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
222   if (session_handle.empty()) {
223     *out_session = legacy_session_;
224   } else {
225     auto it = sessions_.find(session_handle);
226     if (it == sessions_.end()) {
227       return errors::Aborted("Session handle is not found: ", session_handle,
228                              ". Possibly this worker (\"",
229                              legacy_session_->worker_name(),
230                              "\") just restarted.");
231     } else {
232       *out_session = it->second;
233     }
234   }
235   return Status::OK();
236 }
237 
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)238 Status SessionMgr::WorkerSessionForSession(
239     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
240   mutex_lock l(mu_);
241   return WorkerSessionForSessionLocked(session_handle, out_session);
242 }
243 
LegacySession()244 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
245   return legacy_session_;
246 }
247 
SetLogging(bool active)248 void SessionMgr::SetLogging(bool active) {
249   mutex_lock l(mu_);
250   this->is_logging_active_ = active;
251   // Legacy Session
252   if (legacy_session_) {
253     auto* worker_cache = legacy_session_->worker_cache();
254     if (worker_cache) {
255       worker_cache->SetLogging(active);
256     }
257   }
258 
259   for (const auto& session_kv : sessions_) {
260     auto session = session_kv.second.get();
261     if (session) {
262       auto* worker_cache = session->worker_cache();
263       if (worker_cache) {
264         worker_cache->SetLogging(active);
265       }
266     }
267   }
268 }
269 
RetrieveLogs(tensorflow::int64 step_id,LoggingResponse * response)270 void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
271                               LoggingResponse* response) {
272   mutex_lock l(mu_);
273   // Legacy Session
274   if (legacy_session_) {
275     auto* worker_cache = legacy_session_->worker_cache();
276     if (worker_cache) {
277       auto step_stats = StepStats();
278       if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
279         auto* labeled_step_stats = response->add_step();
280         labeled_step_stats->set_step_id(step_id);
281         labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
282       }
283     }
284   }
285   for (const auto& session_kv : sessions_) {
286     auto session = session_kv.second.get();
287     if (session) {
288       auto* worker_cache = session->worker_cache();
289       if (worker_cache) {
290         auto step_stats = StepStats();
291         if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
292           auto* labeled_step_stats = response->add_step();
293           labeled_step_stats->set_step_id(step_id);
294           labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
295         }
296       }
297     }
298   }
299 }
300 
ClearLogs()301 void SessionMgr::ClearLogs() {
302   mutex_lock l(mu_);
303   // Legacy Session
304   if (legacy_session_) {
305     auto* worker_cache = legacy_session_->worker_cache();
306     if (worker_cache) {
307       worker_cache->ClearLogs();
308     }
309   }
310 
311   for (const auto& session_kv : sessions_) {
312     auto session = session_kv.second.get();
313     if (session) {
314       auto* worker_cache = session->worker_cache();
315       if (worker_cache) {
316         worker_cache->ClearLogs();
317       }
318     }
319   }
320 }
321 }  // namespace tensorflow
322