• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/clusters/single_machine.h"
17 
18 #include <atomic>
19 #include <memory>
20 
21 #include "tensorflow/cc/training/queue_runner.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
25 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
26 #include "tensorflow/core/grappler/clusters/utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/kernels/ops_util.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/notification.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/public/session.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
// Set by SingleMachine::Provision() and cleared by SingleMachine::Shutdown().
// It lives at file scope (not per instance) so that at most one SingleMachine
// in the process can be provisioned at any given time.
static std::atomic<bool> already_provisioned{false};
40 
SingleMachine(int timeout_s,int num_cpu_cores,int num_gpus)41 SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
42     : Cluster(timeout_s), expected_init_time_s_(0), closing_(false) {
43   VLOG(1) << "Number of CPU cores: " << num_cpu_cores
44           << " Number of GPUs: " << num_gpus;
45   thread_pool_.reset(new thread::ThreadPool(
46       Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
47 
48   (*options_.config.mutable_device_count())["CPU"] = 1;
49   if (num_gpus > 0) {
50     (*options_.config.mutable_device_count())["GPU"] = num_gpus;
51   }
52   CHECK_GE(num_cpu_cores, 1);
53   options_.config.set_intra_op_parallelism_threads(num_cpu_cores);
54   // Create a session specific thread pool to ensure the threads are reset when
55   // the session is reset.
56   options_.config.add_session_inter_op_thread_pool()->set_num_threads(
57       num_cpu_cores);
58   if (timeout_s > 0) {
59     options_.config.set_operation_timeout_in_ms(timeout_s * 1000);
60   }
61 }
62 
~SingleMachine()63 SingleMachine::~SingleMachine() {
64   CloseSession(false /*use_timeout*/).IgnoreError();
65 
66   // Reset the thread-pool so that there are no outstanding Session::Run(...)s
67   // when we delete the session.
68   thread_pool_.reset();
69 }
70 
Provision()71 Status SingleMachine::Provision() {
72   // This is really ugly: to avoid leaking variables, we need to reset the tf
73   // session every time we're done processing a grappler item. However,
74   // variables are global, and therefore we can't have more than 1 session alive
75   // at a time. This check detects when more that one cluster is provisioned.
76   if (already_provisioned) {
77     return errors::Unavailable(
78         "Can't provision more than one single cluster at a time");
79   }
80 
81   TF_RETURN_IF_ERROR(ResetSession());
82 
83   std::vector<DeviceAttributes> devices;
84   TF_RETURN_IF_ERROR(session_->ListDevices(&devices));
85   for (const auto& dev : devices) {
86     DeviceProperties attr;
87     if (dev.device_type() == "CPU") {
88       attr = GetLocalCPUInfo();
89     } else if (dev.device_type() == "GPU") {
90       DeviceNameUtils::ParsedName parsed;
91       if (!DeviceNameUtils::ParseFullName(dev.name(), &parsed)) {
92         return errors::InvalidArgument(
93             strings::StrCat("Not able to parse GPU device name: ", dev.name()));
94       }
95       TfGpuId tf_gpu_id(parsed.id);
96       PlatformGpuId platform_gpu_id;
97       Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
98       if (!s.ok()) {
99         return errors::Unavailable("Unknown TF GPU device with id ",
100                                    tf_gpu_id.value(), ": ", s.ToString());
101       }
102       attr = GetLocalGPUInfo(platform_gpu_id);
103     } else if (dev.device_type().find("XLA") == string::npos) {
104       // Filter out the fake XLA devices to avoid double counting the actual
105       // hardware resources that are available.
106       attr.set_type(dev.device_type());
107     }
108     // Overwrite the memory size since users might have requested to use only a
109     // fraction of the available device memory.
110     attr.set_memory_size(dev.memory_limit());
111     devices_[dev.name()] = attr;
112   }
113   already_provisioned = true;
114 
115   // Clear highmark stats of all local allocators.
116   if (cpu_allocator_stats_enabled_) {
117     TF_RETURN_IF_ERROR(ClearAllocatorStats());
118   }
119   return Status::OK();
120 }
121 
Initialize(const GrapplerItem & item)122 Status SingleMachine::Initialize(const GrapplerItem& item) {
123   mutex_lock l(this->last_graph_mu_);
124   if (last_graph_ != &item.graph || last_graph_id_ != item.id) {
125     init_ops_ = item.init_ops;
126     expected_init_time_s_ = item.expected_init_time;
127     last_graph_ = nullptr;
128     queue_runner_defs_ = item.queue_runners;
129     last_graph_id_ = item.id;
130   }
131   return Status::OK();
132 }
133 
Shutdown()134 Status SingleMachine::Shutdown() {
135   TF_RETURN_IF_ERROR(ShutdownSession());
136 
137   mutex_lock l(this->last_graph_mu_);
138   last_graph_ = nullptr;
139   already_provisioned = false;
140 
141   return Status::OK();
142 }
143 
Run(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & feed,const std::vector<string> & fetch,RunMetadata * metadata)144 Status SingleMachine::Run(const GraphDef& graph_def,
145                           const std::vector<std::pair<string, Tensor>>& feed,
146                           const std::vector<string>& fetch,
147                           RunMetadata* metadata) {
148   mutex_lock l(this->last_graph_mu_);
149   if (last_graph_ != &graph_def) {
150     TF_RETURN_IF_ERROR(ResetSession());
151     TF_RETURN_IF_ERROR(session_->Create(graph_def));
152     if (!init_ops_.empty()) {
153       init_metadata_ = RunMetadata();
154       int64 timeout_s = timeout_s_ + expected_init_time_s_;
155       TF_RETURN_IF_ERROR(
156           RunWithTimeout({}, init_ops_, &init_metadata_, timeout_s));
157       // The compute cost for init ops is likely to be pessimistic since init
158       // ops are run only once before warmup. Therefore we only keep their
159       // memory costs.
160       for (auto node : *init_metadata_.mutable_cost_graph()->mutable_node()) {
161         node.clear_compute_cost();
162       }
163       // Also clear the timeline to save memory
164       init_metadata_.clear_step_stats();
165     }
166     // We can have at most one hardware trace. Use it for the main graph, and
167     // downgrade tracing of the queue runners to a software trace.
168     RunOptions queue_options = run_options_;
169     if (queue_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
170       queue_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
171     }
172     for (size_t i = 0; i < queue_runner_defs_.size(); ++i) {
173       std::unique_ptr<QueueRunner> queue_runner;
174       TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i],
175                                           coordinator_.get(), &queue_runner));
176 
177       TF_RETURN_IF_ERROR(queue_runner->StartAndCollectCostGraph(session_.get(),
178                                                                 queue_options));
179       TF_RETURN_IF_ERROR(coordinator_->RegisterRunner(std::move(queue_runner)));
180       TF_RETURN_IF_ERROR(coordinator_->GetStatus());
181     }
182 
183     // Warmup TensorFlow if needed
184     for (int i = 0; i < NumWarmupSteps(); ++i) {
185       TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, nullptr));
186     }
187   }
188 
189   if (metadata) {
190     TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, metadata));
191     // Merge the costs of the initialization and the queue runners.
192     CostGraphDef queue_costs;
193     TF_RETURN_IF_ERROR(coordinator_->ExportCostGraph(&queue_costs));
194     MergeCosts(metadata->mutable_cost_graph(), init_metadata_.cost_graph(),
195                queue_costs);
196   } else {
197     TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, nullptr));
198   }
199 
200   last_graph_ = &graph_def;
201 
202   return Status::OK();
203 }
204 
EnablePeakMemoryStats()205 Status SingleMachine::EnablePeakMemoryStats() {
206   EnableCPUAllocatorStats();
207   cpu_allocator_stats_enabled_ = true;
208   // No need to enable GPU allocator stats since its stats are always collected.
209   return Status::OK();
210 }
211 
GetPeakMemoryUsage(std::unordered_map<string,uint64> * device_peak_memory) const212 Status SingleMachine::GetPeakMemoryUsage(
213     std::unordered_map<string, uint64>* device_peak_memory) const {
214   // Cpu_allocator->TracksAllocationSizes() returns true doesn't always mean the
215   // the AllocatorStats would be collected.
216   if (!cpu_allocator_stats_enabled_) {
217     return Status(error::INVALID_ARGUMENT,
218                   "Tracking allocation for CPU is not enabled.");
219   }
220 
221   const DeviceMgr* device_mgr;
222   TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
223   std::vector<Device*> devices = device_mgr->ListDevices();
224 
225   device_peak_memory->clear();
226   for (Device* device : devices) {
227     auto* allocator = device->GetAllocator(AllocatorAttributes());
228     if (!allocator->TracksAllocationSizes()) {
229       return Status(error::INVALID_ARGUMENT,
230                     "Tracking allocation is not enabled.");
231     }
232     absl::optional<AllocatorStats> stats = allocator->GetStats();
233     (*device_peak_memory)[device->name()] =
234         (stats ? stats->peak_bytes_in_use : 0);
235   }
236 
237   return Status::OK();
238 }
239 
RunWithTimeout(const std::vector<std::pair<string,Tensor>> & feed,const std::vector<string> & fetch,RunMetadata * run_metadata)240 Status SingleMachine::RunWithTimeout(
241     const std::vector<std::pair<string, Tensor>>& feed,
242     const std::vector<string>& fetch, RunMetadata* run_metadata) {
243   return RunWithTimeout(feed, fetch, run_metadata, timeout_s_);
244 }
245 
RunWithTimeout(const std::vector<std::pair<string,Tensor>> & feed,const std::vector<string> & fetch,RunMetadata * run_metadata,int64 timeout_s)246 Status SingleMachine::RunWithTimeout(
247     const std::vector<std::pair<string, Tensor>>& feed,
248     const std::vector<string>& fetch, RunMetadata* run_metadata,
249     int64 timeout_s) {
250   // We shouldn't be running or closing the session at this point.
251   {
252     mutex_lock l(close_mu_);
253     CHECK(!closing_);
254   }
255 
256   auto status = std::make_shared<Status>();
257   auto local_metadata = std::make_shared<RunMetadata>();
258   const bool executed_in_time = ExecuteWithTimeout(
259       [this, status, local_metadata, feed, fetch]() {
260         *status = session_->Run(run_options_, feed, {}, fetch, nullptr,
261                                 local_metadata.get());
262       },
263       timeout_s * 1000, thread_pool_.get());
264   if (!executed_in_time) {
265     return errors::DeadlineExceeded("Failed to run the graph after ", timeout_s,
266                                     " seconds, aborting");
267   } else if (run_metadata && status->ok()) {
268     *run_metadata = *local_metadata;
269   }
270   return *status;
271 }
272 
CloseSession(bool use_timeout)273 Status SingleMachine::CloseSession(bool use_timeout) {
274   if (!session_ || !thread_pool_) {
275     return Status::OK();
276   }
277 
278   {
279     mutex_lock l(close_mu_);
280 
281     if (!closing_) {
282       closing_ = true;
283     }
284   }
285 
286   const bool executed_in_time = ExecuteWithTimeout(
287       [&]() {
288         if (this->coordinator_) {
289           this->coordinator_->RequestStop().IgnoreError();
290           // Wait for all the runners to have closed their queues.
291           while (!this->coordinator_->AllRunnersStopped()) {
292             sleep(1);
293           }
294           // Now we can close the session. This should cancel any pending I/O
295           // operation.
296           this->session_->Close().IgnoreError();
297           // Last but not least, we can delete the coordinator.
298           this->coordinator_.reset();
299         } else {
300           this->session_->Close().IgnoreError();
301         }
302 
303         mutex_lock l2(close_mu_);
304         closing_ = false;
305       },
306       use_timeout ? timeout_s_ * 1000 : -1, thread_pool_.get());
307 
308   if (!executed_in_time) {
309     // Let the caller know that we can't shutdown the session, and therefore
310     // can't process any further.
311     return errors::Unavailable("Failed to close the previous session after ",
312                                timeout_s_, " seconds, aborting");
313   }
314 
315   return Status::OK();
316 }
317 
ShutdownSession()318 Status SingleMachine::ShutdownSession() {
319   TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));
320 
321   // Delete the threadpool: this ensures that all the pending closures complete
322   // before we return. Note that if TF deadlocked on us, the closures will
323   // never complete, and the call to thread_pool_.reset() will never return:
324   // therefore we need to delete the threadpool with the background thread.
325   // That thread itself will also never complete, so the user should
326   // abort the process to avoid leaking too many resources.
327   auto n = std::make_shared<Notification>();
328   Env::Default()->SchedClosure([this, n]() {
329     thread_pool_.reset();
330     n->Notify();
331   });
332   int64 timeout_us = 1000000ll * timeout_s_;
333   const bool notified = WaitForNotificationWithTimeout(n.get(), timeout_us);
334   if (!notified) {
335     // Let the caller know that we can't shutdown the session properly since
336     // there are calls to Session::Run() still running.
337     return errors::Unavailable("The session is still running graphs after ",
338                                timeout_s_, " seconds");
339   }
340 
341   return Status::OK();
342 }
343 
ResetSession()344 Status SingleMachine::ResetSession() {
345   if (session_) {
346     LOG(INFO) << "Cleaning up previous session";
347 
348     // Make sure the session is properly closed
349     TF_RETURN_IF_ERROR(ShutdownSession());
350 
351     // Destroying the object deletes all its variables as well. This is only
352     // true for DirectSession.
353     session_.reset();
354   }
355 
356   LOG(INFO) << "Starting new session";
357 
358   // Create a new threadpool
359   thread_pool_.reset(new thread::ThreadPool(
360       Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
361 
362   session_.reset(NewSession(options_));
363   if (!session_) {
364     return errors::Unknown("Failed to create session");
365   }
366   coordinator_.reset(new Coordinator());
367 
368   // Build the DeviceSet.
369   device_set_.reset(new DeviceSet);
370   const DeviceMgr* device_mgr;
371   TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
372   for (auto d : device_mgr->ListDevices()) {
373     device_set_->AddDevice(d);
374     // We currently don't care about the client device.
375   }
376 
377   return Status::OK();
378 }
379 
MergeCosts(CostGraphDef * graph_costs,const CostGraphDef & init_costs,const CostGraphDef & queue_costs)380 void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
381                                const CostGraphDef& init_costs,
382                                const CostGraphDef& queue_costs) {
383   graph_costs->mutable_node()->Reserve(graph_costs->node_size() +
384                                        init_costs.node_size() +
385                                        queue_costs.node_size());
386   std::unordered_set<string> nodes_seen;
387   int queue_costs_id_offset = graph_costs->node_size();
388   for (const auto& node : graph_costs->node()) {
389     nodes_seen.insert(node.name());
390     if (node.id() >= queue_costs_id_offset) {
391       queue_costs_id_offset = node.id() + 1;
392     }
393   }
394 
395   int init_costs_id_offset = queue_costs_id_offset + queue_costs.node_size();
396   // The costs obtained by running the main graph could be more stable than
397   // the one we get from the queue runners since the queue runners run
398   // asynchronously.
399   for (const auto& node : queue_costs.node()) {
400     if (nodes_seen.find(node.name()) != nodes_seen.end()) {
401       continue;
402     }
403 
404     auto* new_node = graph_costs->add_node();
405     new_node->MergeFrom(node);
406 
407     new_node->set_id(node.id() + queue_costs_id_offset);
408     if (new_node->id() >= init_costs_id_offset) {
409       init_costs_id_offset = new_node->id() + 1;
410     }
411 
412     for (auto& input_info : *new_node->mutable_input_info()) {
413       input_info.set_preceding_node(input_info.preceding_node() +
414                                     queue_costs_id_offset);
415     }
416     for (auto& control_input : *new_node->mutable_control_input()) {
417       control_input += queue_costs_id_offset;
418     }
419   }
420 
421   // Don't overwrite the costs with that generated during initialization since
422   // these are possibly outdated.
423   for (const auto& node : init_costs.node()) {
424     if (nodes_seen.find(node.name()) != nodes_seen.end()) {
425       continue;
426     }
427 
428     auto* new_node = graph_costs->add_node();
429     new_node->MergeFrom(node);
430 
431     new_node->set_id(node.id() + init_costs_id_offset);
432     for (auto& input_info : *new_node->mutable_input_info()) {
433       input_info.set_preceding_node(input_info.preceding_node() +
434                                     init_costs_id_offset);
435     }
436     for (auto& control_input : *new_node->mutable_control_input()) {
437       control_input += init_costs_id_offset;
438     }
439   }
440 }
441 
ClearAllocatorStats() const442 Status SingleMachine::ClearAllocatorStats() const {
443   // Cpu_allocator->TracksAllocationSizes() returns true doesn't always mean the
444   // the AllocatorStats would be collected.
445   if (!cpu_allocator_stats_enabled_) {
446     return Status(error::INVALID_ARGUMENT,
447                   "Tracking allocation for CPU is not enabled.");
448   }
449 
450   const DeviceMgr* device_mgr;
451   TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
452   std::vector<Device*> devices = device_mgr->ListDevices();
453 
454   for (Device* device : devices) {
455     auto* allocator = device->GetAllocator(AllocatorAttributes());
456     if (!allocator->TracksAllocationSizes()) {
457       return Status(error::INVALID_ARGUMENT,
458                     "Tracking allocation is not enabled.");
459     }
460     allocator->ClearStats();
461   }
462   return Status::OK();
463 }
464 
465 }  // namespace grappler
466 }  // namespace tensorflow
467