/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/step_stats_collector.h"

#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
const int kMaxAllocReportNodes = 100;
const float kMaxAllocReportFraction = 0.99;

struct AllocStats {
  std::map<int64, std::vector<string>> nodes_by_size;
  int64 total_bytes = 0;
  int64 total_nodes = 0;
};

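// Helpers to identify inter-device transfer ops; statistics are not collected
// for these nodes (see CreateNodeExecStats below).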
bool IsRecv(const NodeDef* node) {
  return node->op() == "_Recv" || node->op() == "_HostRecv";
}

bool IsSend(const NodeDef* node) {
  return node->op() == "_Send" || node->op() == "_HostSend";
}

}  // namespace

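// Convenience constructor: allocates a fresh NodeExecStats proto and tags it
// with the node's name.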
NodeExecStatsWrapper::NodeExecStatsWrapper(
    const NodeDef* node, StepStatsCollector* step_stats_collector)
    : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
                           step_stats_collector) {
  stats_->set_node_name(node->name());
}

NodeExecStatsWrapper::NodeExecStatsWrapper(
    std::unique_ptr<NodeExecStats> stats, const NodeDef* node,
    StepStatsCollector* step_stats_collector)
    : stats_(std::move(stats)),
      node_(node),
      step_stats_collector_(step_stats_collector) {}

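// Builds the timeline label for this node (per-allocator memory usage followed
// by a "name = op(inputs)" summary, using the rendezvous tensor name and peer
// device for Send/Recv nodes) and hands the wrapper over to the collector.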
void NodeExecStatsWrapper::Done(const string& device) {
  // TODO(tucker): merge with the DetailText function in session.cc in a common
  // location.
  DCHECK(node_);
  string memory;
  for (auto& all : stats_->memory()) {
    int64 tot = all.total_bytes();
    if (tot >= 0.1 * 1048576.0) {
      int64 peak = all.peak_bytes();
      if (peak > 0) {
        memory =
            strings::StrCat(memory, "[", all.allocator_name(),
                            strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
                                            peak / 1048576.0));
      } else {
        memory = strings::StrCat(memory, "[", all.allocator_name(),
                                 strings::Printf(" %.1fMB] ", tot / 1048576.0));
      }
    }
  }
  const AttrSlice attrs(*node_);
  string text;
  if (IsSend(node_)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string recv_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           tensor_name, " @", recv_device, ")");
  } else if (IsRecv(node_)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    string send_device;
    TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           tensor_name, " @", send_device, ")");
  } else {
    text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(",
                           absl::StrJoin(node_->input(), ", "), ")");
  }
  stats_->set_timeline_label(text);
  step_stats_collector_->Save(device, this);
}

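// The Record* methods below capture wall-clock timestamps from
// Env::Default()->NowNanos(). The executor start time is stored as an absolute
// value; later events are stored as offsets relative to it, in both micro- and
// nanosecond resolution.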
void NodeExecStatsWrapper::RecordExecutorStarted() {
  int64 now_nanos = Env::Default()->NowNanos();
  stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
  stats_->set_all_start_nanos(now_nanos);
}

void NodeExecStatsWrapper::RecordComputeStarted() {
  int64 now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                  stats_->all_start_micros());
  stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::RecordComputeEnded() {
  int64 now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                stats_->all_start_micros());
  stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::RecordExecutorEnded() {
  int64 now_nanos = Env::Default()->NowNanos();
  DCHECK_NE(stats_->all_start_micros(), 0);
  DCHECK_NE(stats_->all_start_nanos(), 0);
  stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
                                 stats_->all_start_micros());
  stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
}

void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
  stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
  stats_->set_scheduled_nanos(nanos);
}

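// Records memory usage for the kernel: per-allocator usage from the wrapped
// (tracking) allocators consumed from the context, plus the temporary and
// persistent memory sizes and the persistent allocation ids.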
void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
  for (const auto& allocator_pair : ctx->ConsumeWrappedAllocators()) {
    AddAllocation(allocator_pair.first, allocator_pair.second);
  }
  auto* ms = stats_->mutable_memory_stats();
  ms->set_temp_memory_size(ctx->temp_memory_allocated());
  for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
    ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
  }
  ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
}

void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
  DCHECK(tensor);
  NodeOutput* node_output = stats_->add_output();
  node_output->set_slot(slot);
  tensor->FillDescription(node_output->mutable_tensor_description());
}

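// Snapshots the total/peak/live byte counts from a tracking allocator and, if
// the underlying allocator exposes stats, its current bytes in use. The
// detailed per-allocation records are pulled in later by Finalize().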
void NodeExecStatsWrapper::AddAllocation(
    Allocator* allocator, TrackingAllocator* tracking_allocator) {
  AllocatorMemoryUsed* memory = stats_->add_memory();
  memory->set_allocator_name(allocator->Name());
  auto sizes = tracking_allocator->GetSizes();
  memory->set_total_bytes(std::get<0>(sizes));
  memory->set_peak_bytes(std::get<1>(sizes));
  memory->set_live_bytes(std::get<2>(sizes));

  absl::optional<AllocatorStats> stats = allocator->GetStats();
  if (stats) {
    memory->set_allocator_bytes_in_use(stats->bytes_in_use);
  }
  allocations_.push_back(std::make_pair(memory, tracking_allocator));
}

void NodeExecStatsWrapper::Finalize() {
  for (auto& alloc : allocations_) {
    AllocatorMemoryUsed* memory = alloc.first;
    for (auto& record : alloc.second->GetRecordsAndUnRef()) {
      auto* r = memory->add_allocation_records();
      r->set_alloc_bytes(record.alloc_bytes);
      r->set_alloc_micros(record.alloc_micros);
    }
  }
  allocations_.clear();
}

StepStatsCollector::StepStatsCollector(StepStats* step_stats)
    : finalized_(false), step_stats_(step_stats) {}

static int ExtractGpuWithStreamAll(string device_name) {
  // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
  // and if it does, return the GPU id (always non-negative). Otherwise return
  // -1.

  // The easiest way to match this pattern with a Scanner is to parse the
  // string in reverse, starting from the end.
  std::reverse(device_name.begin(), device_name.end());
  strings::Scanner scanner(device_name);
  // Check that the string ends with '/stream:all'.
  scanner.OneLiteral("lla:maerts/");
  // Capture the digits if present.
  scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture();
  // Check that the digits are preceded by the 'device:GPU:' string.
  scanner.OneLiteral(":UPG:ecived");
  StringPiece capture;
  bool matched = scanner.GetResult(nullptr, &capture);

  if (!matched) {
    return -1;
  } else {
    // Convert the captured string into an integer, after putting the digits
    // back in the right order.
    string ordered_capture(capture);
    std::reverse(ordered_capture.begin(), ordered_capture.end());
    int gpu_id;
    CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
    return gpu_id;
  }
}

static int ExtractGpuWithoutStream(string device_name) {
  // Check if the device name matches the ".*gpu:(\\d+)$" regexp, and if it
  // does, return the GPU id (always non-negative). Otherwise return -1.

  // The easiest way to match this pattern with a Scanner is to parse the
  // string in reverse, starting from the end.
  std::reverse(device_name.begin(), device_name.end());
  strings::Scanner scanner(device_name);
  // Capture the trailing digits if present.
  scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture();
  // Check that the digits are preceded by the 'device:GPU:' string.
  scanner.OneLiteral(":UPG:ecived");
  StringPiece capture;
  bool matched = scanner.GetResult(nullptr, &capture);

  if (!matched) {
    return -1;
  } else {
    // Convert the captured string into an integer, after putting the digits
    // back in the right order.
    string ordered_capture(capture);
    std::reverse(ordered_capture.begin(), ordered_capture.end());
    int gpu_id;
    CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
    return gpu_id;
  }
}

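// Feeds the collected step stats into the cost model of every graph in
// device_map: output sizes, allocation ids, memory stats and per-node
// execution times.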
void StepStatsCollector::BuildCostModel(
    CostModelManager* cost_model_manager,
    const std::unordered_map<string, const Graph*>& device_map) {
  mutex_lock lock(mu_);

  if (!finalized_) {
    FinalizeInternal();
  }
  // Hardware stats for a gpu are available under a fake device named
  // "gpu:<id>/stream:all".
  // Use them instead of regular stats whenever they're available to extract
  // the execution stats of a particular node since they're more accurate.
  // However, hardware traces don't record memory usage, so we still have to
  // rely on regular traces to track memory usage.
  struct DeviceStats {
    const DeviceStepStats* regular_stats;
    const DeviceStepStats* hardware_stats;
  };

  std::unordered_map<StringPiece, DeviceStats, StringPieceHasher>
      per_device_stats;
  std::unordered_map<int, const DeviceStepStats*> gpu_hardware_stats;

  for (int i = 0; i < step_stats_->dev_stats_size(); ++i) {
    const DeviceStepStats& device_stats = step_stats_->dev_stats(i);
    const string& device_name = device_stats.device();
    const int gpu_id = ExtractGpuWithStreamAll(device_name);
    if (gpu_id >= 0) {
      // These are gpu hardware stats.
      gpu_hardware_stats.emplace(gpu_id, &device_stats);
    } else {
      // These are regular stats.
      per_device_stats.emplace(device_name,
                               DeviceStats{&device_stats, nullptr});
    }
  }

  for (auto& itr : per_device_stats) {
    const StringPiece device_name = itr.first;
    const int gpu_id = ExtractGpuWithoutStream(string(device_name));
    if (gpu_id >= 0) {
      // Reference the gpu hardware stats in addition to the regular stats
      // for this gpu device if they're available.
      if (gpu_hardware_stats.find(gpu_id) != gpu_hardware_stats.end()) {
        itr.second.hardware_stats = gpu_hardware_stats.find(gpu_id)->second;
      }
    }
  }

  for (const auto& itr : device_map) {
    const StringPiece device = itr.first;
    if (per_device_stats.find(device) == per_device_stats.end()) {
      continue;
    }

    const Graph* graph = itr.second;
    CostModel* cm = cost_model_manager->FindOrCreateCostModel(graph);
    cm->IncrementUpdateTimes();

    std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node;
    for (Node* n : graph->nodes()) {
      name_to_node.emplace(n->name(), n);
    }

    const DeviceStats& dev_stats = per_device_stats.find(device)->second;

    std::unordered_map<string, NodeExecStats> name_to_hw_node_stats;
    if (dev_stats.hardware_stats) {
      for (const auto& node_stats : dev_stats.hardware_stats->node_stats()) {
        string node_name = node_stats.node_name();
        // Remove the op-name suffix (e.g. ":Conv2D") at the end of the node
        // name.
        size_t pos = node_name.find_first_of(':');
        if (pos != std::string::npos) {
          node_name = node_name.substr(0, pos);
        }
        // Certain ops (e.g. Conv2D) are implemented with multiple GPU kernels,
        // which results in multiple NodeExecStats with the same node name. For
        // such ops, we sum up the time for all their GPU kernels.
        if (name_to_hw_node_stats.find(node_name) !=
            name_to_hw_node_stats.end()) {
          int64 time = name_to_hw_node_stats[node_name].op_end_rel_micros();
          name_to_hw_node_stats[node_name].set_op_end_rel_micros(
              time + node_stats.op_end_rel_micros());
        } else {
          name_to_hw_node_stats.emplace(node_name, node_stats);
        }
      }
    }

    for (int i = 0; i < dev_stats.regular_stats->node_stats_size(); ++i) {
      const NodeExecStats& stats = dev_stats.regular_stats->node_stats(i);
      const Node* node = name_to_node[stats.node_name()];
      if (node) {
        for (int i = 0; i < stats.output_size(); ++i) {
          const auto& output = stats.output(i);
          int output_slot = output.slot();
          cm->RecordMaxMemorySize(node, output_slot,
                                  Bytes(output.tensor_description()
                                            .allocation_description()
                                            .allocated_bytes()),
                                  output.tensor_description().shape(),
                                  node->output_types()[output_slot]);
          cm->RecordAllocationId(node, output_slot,
                                 output.tensor_description()
                                     .allocation_description()
                                     .allocation_id());
        }
        cm->RecordMemoryStats(node, stats.memory_stats());
        // Use hardware stats to record the execution time if they're
        // available, otherwise use the regular (less accurate) stats.
        string node_name = dev_stats.regular_stats->node_stats(i).node_name();
        if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) !=
                                            name_to_hw_node_stats.end()) {
          const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name];
          cm->RecordMaxExecutionTime(
              node, Microseconds(hw_stats.op_end_rel_micros()));
        } else {
          cm->RecordMaxExecutionTime(node,
                                     Microseconds(stats.op_end_rel_micros()));
        }
      }
    }
  }
}

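// Takes ownership of the given NodeExecStats proto, wrapping it (with no
// NodeDef attached) and saving it under `device`.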
void StepStatsCollector::Save(const string& device,
                              NodeExecStats* node_stats_pb) {
  Save(device,
       new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
                                nullptr, this));
}

void StepStatsCollector::Save(const string& device,
                              NodeExecStatsWrapper* node_stats) {
  if (!node_stats) return;
  VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
  {
    mutex_lock l(mu_);
    if (finalized_) {
      LOG(WARNING) << "stats saved after finalize will not be collected.";
    }
    if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
      VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
      delete node_stats;
      return;
    }
    auto& device_stats = dev_stats_[device];
    device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
    collected_nodes_++;
  }
}

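// Remembers a human-readable name for `thread_id` on `device`; the names are
// copied into the corresponding DeviceStepStats in FinalizeInternal().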
void StepStatsCollector::SaveThreadName(const string& device,
                                        const uint32 thread_id,
                                        const string& thread_name) {
  VLOG(1) << "Save dev " << device << " thread id " << thread_id << " name "
          << thread_name;
  {
    mutex_lock l(mu_);
    if (finalized_) {
      LOG(WARNING) << "thread_name saved after finalize will not be collected.";
    }
    auto& thread_names_map = thread_names_[device];
    thread_names_map[thread_id] = thread_name;
  }
}

NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
    const NodeDef* node) {
  // Only collect statistics for non-transfer nodes.
  if (IsSend(node) || IsRecv(node)) {
    return nullptr;
  }
  return new NodeExecStatsWrapper(node, this);
}

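// For an OOM error message, builds a per-<device, allocator> report of the
// live allocations at the time of the failure, listing the largest consumers
// first and stopping after kMaxAllocReportNodes nodes or
// kMaxAllocReportFraction of the total bytes. Returns an empty string if the
// error does not mention OOM.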
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
  mutex_lock l(mu_);
  if (err.find("OOM") == err.npos) {
    return "";
  }
  // <device, allocator> -> AllocStats
  std::map<std::pair<string, string>, AllocStats> allocs_map;
  string report = "\n";
  for (const auto& dev_stat : dev_stats_) {
    const string& device = dev_stat.first;
    // Only print the device that has OOM.
    // TODO(xpan): Extract device from err first to speed it up.
    if (err.find(device) == err.npos) {
      continue;
    }
    // NodeExecStatsWrapper*
    for (const auto& stats : dev_stat.second) {
      // std::pair<AllocatorMemoryUsed*, TrackingAllocator*>
      for (const auto& alloc : stats->allocations_) {
        // Only print the allocator that has OOM.
        // TODO(xpan): Extract device from err first to speed it up.
        if (err.find(alloc.first->allocator_name()) == err.npos) {
          continue;
        }
        auto dev_allocator =
            std::make_pair(dev_stat.first, alloc.first->allocator_name());
        AllocStats& dev_allocs_stats = allocs_map[dev_allocator];
        TrackingAllocator* tracking_alloc = alloc.second;
        gtl::InlinedVector<AllocRecord, 4> cur_records =
            tracking_alloc->GetCurrentRecords();
        int64 cur_bytes = 0;
        for (const auto& r : cur_records) {
          cur_bytes += r.alloc_bytes;
        }
        if (cur_bytes > 0) {
          dev_allocs_stats.total_bytes += cur_bytes;
          dev_allocs_stats.total_nodes++;
          dev_allocs_stats.nodes_by_size[cur_bytes].push_back(
              stats->stats()->node_name());
        }
      }
    }
  }

  for (const auto& dev_allocs_it : allocs_map) {
    const auto& dev = dev_allocs_it.first;
    const AllocStats& dev_allocs_stats = dev_allocs_it.second;
    int64 reported_bytes = 0;
    int64 reported_nodes = 0;
    bool done = false;
    strings::StrAppend(&report, "\nCurrent usage from device: ", dev.first,
                       ", allocator: ", dev.second, "\n");
    // Print allocation stats of the <device, allocator> pair.
    for (auto it = dev_allocs_stats.nodes_by_size.rbegin();
         it != dev_allocs_stats.nodes_by_size.rend(); ++it) {
      for (const string& node_name : it->second) {
        reported_bytes += it->first;
        strings::StrAppend(&report, "  ",
                           strings::HumanReadableNumBytes(it->first), " from ",
                           node_name, "\n");
        if (++reported_nodes > kMaxAllocReportNodes ||
            reported_bytes >=
                dev_allocs_stats.total_bytes * kMaxAllocReportFraction) {
          done = true;
          break;
        }
      }
      if (done) break;
    }
    int64 remain_nodes = dev_allocs_stats.total_nodes - reported_nodes;
    int64 remain_bytes = dev_allocs_stats.total_bytes - reported_bytes;
    if (remain_nodes > 0) {
      strings::StrAppend(&report, "  Remaining ", remain_nodes, " nodes with ",
                         strings::HumanReadableNumBytes(remain_bytes), "\n");
    }
  }
  return report;
}

void StepStatsCollector::Finalize() {
  mutex_lock l(mu_);
  FinalizeInternal();
}

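// Finalizes the collected stats and moves them into `step_stats`, resetting
// the collected-node count.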
void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
  mutex_lock l(mu_);
  CHECK(step_stats_);
  FinalizeInternal();
  step_stats->Swap(step_stats_);
  collected_nodes_ = 0;
}

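// Merges everything accumulated in dev_stats_ and thread_names_ into the
// StepStats proto, creating a DeviceStepStats entry per device as needed.
// Subsequent calls are no-ops once finalized_ is set.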
void StepStatsCollector::FinalizeInternal() {
  if (!step_stats_ || finalized_) {
    return;
  }
  finalized_ = true;
  std::map<string, DeviceStepStats*> dev_stats_pb;
  for (auto& ds : *step_stats_->mutable_dev_stats()) {
    dev_stats_pb[ds.device()] = &ds;
  }
  for (const auto& dev_stat : dev_stats_) {
    if (dev_stats_pb.find(dev_stat.first) == dev_stats_pb.end()) {
      DeviceStepStats* ndev_stat = step_stats_->add_dev_stats();
      ndev_stat->set_device(dev_stat.first);
      dev_stats_pb[dev_stat.first] = ndev_stat;
    }
    DeviceStepStats* dss = dev_stats_pb.at(dev_stat.first);
    for (auto& stats : dev_stat.second) {
      stats->Finalize();
      stats->stats()->Swap(dss->add_node_stats());
    }
  }
  for (const auto& device_thread : thread_names_) {
    if (dev_stats_pb.find(device_thread.first) == dev_stats_pb.end()) {
      // Skip devices without DeviceStepStats.
      continue;
    }
    DeviceStepStats* dss = dev_stats_pb.at(device_thread.first);
    for (const auto& thread_name : device_thread.second) {
      (*dss->mutable_thread_names())[thread_name.first] = thread_name.second;
    }
  }
}
}  // namespace tensorflow