1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
17
18 #include <stdio.h>
19
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/io/zlib_compression_options.h"
28 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
29 #include "tensorflow/core/platform/regexp.h"
30 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
31
32 namespace tensorflow {
33 namespace tfprof {
34 namespace {
35
// Appended to a forward op's trace names in TFCode::Build so gradient
// frames appear as separate entries next to their forward counterparts.
const char* const kGradientSuffix = " (gradient)";
37
38 // Convert to Trace proto into a short readable string.
GetTraceString(const CallStack::Trace & trace)39 std::string GetTraceString(const CallStack::Trace& trace) {
40 std::string ntrace =
41 absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
42 if (trace.function().length() < 20) {
43 absl::StrAppend(&ntrace, ":", trace.function());
44 } else {
45 absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
46 }
47 return ntrace;
48 }
49
// Returns true iff 'name' looks like a gradient op generated for a forward
// op, i.e. matches ...gradients/<forward_name>_grad/...; on success stores
// the inferred forward op name in '*forward_name'.
// TODO(xpan): This is hacky.
bool IsGradNode(const std::string& name, std::string* forward_name) {
  static const char kPrefix[] = "gradients/";
  static const char kSuffix[] = "_grad/";
  const size_t grad_prefix = name.find(kPrefix);
  const size_t grad_suffix = name.find(kSuffix);
  if (grad_prefix == std::string::npos || grad_suffix == std::string::npos) {
    return false;
  }
  const size_t start = grad_prefix + (sizeof(kPrefix) - 1);
  // Compare positions instead of subtracting: the old 'len <= 0' test used
  // unsigned arithmetic, so when "_grad/" occurred before "gradients/" the
  // subtraction wrapped around and a bogus forward name was returned.
  if (grad_suffix <= start) {
    return false;
  }
  *forward_name = name.substr(start, grad_suffix - start);
  return true;
}
67
68 // StringTable maps each string to an id.
69 class StringTable {
70 public:
StringTable()71 StringTable() {
72 // Pprof requires first entry in string_table to be ''.
73 string_id_[""] = 0;
74 all_strings_.push_back("");
75 }
76
77 // Returns the index of a string. If not found, inserts the string and
78 // return the inserted index.
GetIndex(const string & str)79 uint64 GetIndex(const string& str) {
80 auto idx = string_id_.find(str);
81 if (idx != string_id_.end()) {
82 return idx->second;
83 }
84 all_strings_.push_back(str);
85 return string_id_.insert(std::pair<string, int64>(str, string_id_.size()))
86 .first->second;
87 }
88
strings() const89 const std::vector<string>& strings() const { return all_strings_; }
90
91 private:
92 std::map<string, uint64> string_id_;
93 std::vector<string> all_strings_;
94 };
95
96 // FunctionTable maps each function to an id.
97 class FunctionTable {
98 public:
FunctionTable(StringTable * string_table)99 explicit FunctionTable(StringTable* string_table)
100 : string_table_(string_table) {}
101
102 // Returns the index of a function. If not found, adds a function proto
103 // and returns the function index.
GetIndex(const string & file_path,const string & func_name,uint64 func_start_line)104 uint64 GetIndex(const string& file_path, const string& func_name,
105 uint64 func_start_line) {
106 auto key = std::tuple<string, string, uint64>(file_path, func_name,
107 func_start_line);
108 auto idx = function_table_.find(key);
109 if (idx != function_table_.end()) {
110 return idx->second.id();
111 }
112 pprof::Function* func_pb = &function_table_[key];
113 // function index should start from 1.
114 func_pb->set_id(function_table_.size());
115
116 string file_base(io::Basename(file_path));
117 file_base = file_base.substr(0, file_base.find_last_of("."));
118 func_pb->set_name(
119 string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
120 func_pb->set_filename(string_table_->GetIndex(file_path));
121 func_pb->set_start_line(func_start_line);
122 return func_pb->id();
123 }
124
125 const std::map<std::tuple<string, string, uint64>, pprof::Function>&
functions() const126 functions() const {
127 return function_table_;
128 }
129
130 private:
131 StringTable* string_table_;
132 std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
133 };
134
135 // LocationTable maps each function call to an id.
136 class LocationTable {
137 public:
LocationTable(FunctionTable * function_table)138 explicit LocationTable(FunctionTable* function_table)
139 : function_table_(function_table) {}
140
141 // Returns the index of a function call localtion. If not found, adds a
142 // location proto and returns the location index.
GetIndex(const string & file_path,uint64 line_number,const string & called_function_name,const string & called_file_path,uint64 called_func_start_line)143 uint64 GetIndex(const string& file_path, uint64 line_number,
144 const string& called_function_name,
145 const string& called_file_path,
146 uint64 called_func_start_line) {
147 auto key = std::tuple<string, string, uint64>(
148 file_path, called_function_name, line_number);
149
150 auto idx = location_table_.find(key);
151 if (idx != location_table_.end()) {
152 return idx->second.id();
153 }
154 pprof::Location* location_pb = &location_table_[key];
155 location_pb->set_id(location_table_.size());
156 pprof::Line* line_pb = location_pb->add_line();
157 line_pb->set_function_id(function_table_->GetIndex(
158 called_file_path, called_function_name, called_func_start_line));
159 line_pb->set_line(line_number);
160 return location_pb->id();
161 }
162
163 const std::map<std::tuple<string, string, uint64>, pprof::Location>&
locations() const164 locations() const {
165 return location_table_;
166 }
167
168 private:
169 FunctionTable* function_table_;
170 std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
171 };
172
// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace. It includes all graph nodes
  // created by it. 'location_ids' contains the call stack, from callee to
  // caller. This method adds the statistics of graph nodes created by the
  // python call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // displayed leaf might not be true leaf. Retrieve the true leaves for
    // stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    for (const CodeNode* cn : all_leaf) {
      // NOTE(review): 'gn_it' is taken by value; presumably graph_nodes()
      // yields (name, TFGraphNode*) pairs — confirm it is not copying
      // anything heavy.
      for (auto gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        // Label the sample with the graph node it came from.
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        // value[0] is always the graph-node count.
        sample_pb->mutable_value()->Add(1);
        // value[1] depends on the single selected attribute. The kShown
        // indices below mirror the pprof whitelist in TFCode::ShowInternal
        // and the units chosen in PprofProfileImpl::Build.
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          // Total (accelerator + cpu) execution time for this step.
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  // All accumulated samples, keyed by (possibly '@'-suffixed) node name.
  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
  // Returns the true leaves (nodes with no children) under 'root',
  // including 'root' itself when it has no children.
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};
259
260 class PprofProfileImpl : public PprofProfile {
261 public:
PprofProfileImpl(const Options * opts)262 explicit PprofProfileImpl(const Options* opts)
263 : opts_(opts),
264 func_table_(new FunctionTable(&string_table_)),
265 loc_table_(new LocationTable(func_table_.get())),
266 samples_(new Samples(&string_table_, opts)) {}
267
AddLocation(const CodeNode * callee,const CodeNode * caller)268 uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
269 const string& file_path = caller->file();
270 uint64 lineno = caller->lineno();
271 const string& callee_file_path = callee->file();
272 const string& callee_function = callee->function();
273 uint64 callee_func_start_line = callee->func_start_line();
274
275 return loc_table_->GetIndex(file_path, lineno, callee_function,
276 callee_file_path, callee_func_start_line);
277 }
278
AddSample(const CodeNode * leaf,std::vector<uint64> * call_ids)279 void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
280 std::vector<uint64> reversed_call_ids;
281 std::reverse_copy(call_ids->begin(), call_ids->end(),
282 std::back_inserter(reversed_call_ids));
283 samples_->Add(leaf, reversed_call_ids);
284 }
285
WritePprofProfile(const string & filename)286 Status WritePprofProfile(const string& filename) override {
287 pprof::Profile profile_pb;
288 Build(&profile_pb);
289
290 std::unique_ptr<WritableFile> file;
291 Status s = Env::Default()->NewWritableFile(filename, &file);
292 if (!s.ok()) return s;
293
294 int32 buf_size = 1024 * 1024;
295 io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
296 file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
297 s = zlib_output_buffer->Init();
298 if (!s.ok()) {
299 delete zlib_output_buffer;
300 return s;
301 }
302 s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
303 if (!s.ok()) {
304 delete zlib_output_buffer;
305 return s;
306 }
307 s = zlib_output_buffer->Close();
308 if (!s.ok()) {
309 delete zlib_output_buffer;
310 return s;
311 }
312 absl::FPrintF(stdout,
313 "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
314 filename);
315 delete zlib_output_buffer;
316 return s;
317 }
318
319 private:
Build(pprof::Profile * profile_pb)320 void Build(pprof::Profile* profile_pb) {
321 string sample_type_description = "count";
322 auto sample_type = profile_pb->mutable_sample_type()->Add();
323 sample_type->set_type(string_table_.GetIndex(sample_type_description));
324 sample_type->set_unit(string_table_.GetIndex("count"));
325
326 string type = *opts_->select.begin();
327 sample_type_description = type;
328 sample_type = profile_pb->mutable_sample_type()->Add();
329 sample_type->set_type(string_table_.GetIndex(sample_type_description));
330 if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
331 sample_type->set_unit(string_table_.GetIndex("microseconds"));
332 if (type == kShown[1]) {
333 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
334 "Sum of accelerator execution time and cpu execution time."));
335 } else if (type == kShown[9]) {
336 profile_pb->mutable_comment()->Add(
337 string_table_.GetIndex("Accelerator execution time."));
338 } else if (type == kShown[10]) {
339 profile_pb->mutable_comment()->Add(
340 string_table_.GetIndex("CPU execution time."));
341 }
342 } else if (type == kShown[0]) {
343 sample_type->set_unit(string_table_.GetIndex("bytes"));
344 profile_pb->mutable_comment()->Add(
345 string_table_.GetIndex("Sum of operation total memory requests, "
346 "excluding deallocations."));
347 } else if (type == kShown[11]) {
348 sample_type->set_unit(string_table_.GetIndex("bytes"));
349 profile_pb->mutable_comment()->Add(
350 string_table_.GetIndex("Sum of operation peak memory usage."));
351 } else if (type == kShown[12]) {
352 sample_type->set_unit(string_table_.GetIndex("bytes"));
353 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
354 "Sum of operation allocated memory after finish."));
355 } else if (type == kShown[13]) {
356 sample_type->set_unit(string_table_.GetIndex("bytes"));
357 profile_pb->mutable_comment()->Add(
358 string_table_.GetIndex("Sum of operation output size."));
359 } else if (type == kShown[2]) {
360 sample_type->set_unit(string_table_.GetIndex("count"));
361 profile_pb->mutable_comment()->Add(
362 string_table_.GetIndex("Model parameters."));
363 } else if (type == kShown[3]) {
364 sample_type->set_unit(string_table_.GetIndex("count"));
365 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
366 "Model float operations (Only available if defined)."));
367 } else {
368 absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
369 }
370
371 for (const string& str : string_table_.strings()) {
372 *profile_pb->mutable_string_table()->Add() = str;
373 }
374 for (const auto& sample_it : samples_->samples()) {
375 // TODO(xpan): Consider swap.
376 profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
377 }
378 for (const auto& function_it : func_table_->functions()) {
379 profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
380 }
381 for (const auto& location_it : loc_table_->locations()) {
382 profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
383 }
384 }
385
386 const Options* opts_;
387 StringTable string_table_;
388 std::unique_ptr<FunctionTable> func_table_;
389 std::unique_ptr<LocationTable> loc_table_;
390 std::unique_ptr<Samples> samples_;
391 };
392 } // namespace
393
AddNode(TFGraphNode * node)394 void TFCode::AddNode(TFGraphNode* node) {
395 if (!node->call_stack() || node->call_stack()->traces().empty()) {
396 return;
397 }
398 // We infer the forward operation name from gradient op name. So, we can
399 // map gradient op traces to forward op traces.
400 // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
401 string forward_name;
402 if (IsGradNode(node->name(), &forward_name)) {
403 auto grad_nodes_it = grad_nodes_.find(forward_name);
404 if (grad_nodes_it != grad_nodes_.end()) {
405 grad_nodes_it->second.push_back(node);
406 } else {
407 grad_nodes_.insert(
408 std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
409 }
410 return;
411 } else {
412 forward_nodes_[node->name()] = node;
413 }
414
415 if (!root_) {
416 graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
417 root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
418 }
419
420 CodeNode* pre_code_node = root_.get();
421 // TODO(xpan): Consider to release CodeDef after TFCode is built. It
422 // takes a lot of memory.
423 std::set<string> traces;
424 for (int i = 0; i < node->call_stack()->traces().size(); ++i) {
425 // Unlike op name, which is globally unique, trace name is only unique
426 // w.r.t. it's parent.
427 const string& trace = GetTraceString(node->call_stack()->traces().at(i));
428 traces.insert(trace);
429 pre_code_node = pre_code_node->AddChildren(
430 trace, &node->call_stack()->traces().at(i), "");
431 if (i == node->call_stack()->traces().size() - 1) {
432 pre_code_node->node->AddGraphNode(node);
433 }
434 }
435 }
436
Build()437 void TFCode::Build() {
438 int64 unaccounted_nodes = 0;
439 for (auto it : grad_nodes_) {
440 const string& forward_name = it.first;
441 auto forward_it = forward_nodes_.find(forward_name);
442 if (forward_it == forward_nodes_.end()) {
443 unaccounted_nodes += 1;
444 continue;
445 }
446 TFGraphNode* fn = forward_it->second;
447 CodeNode* leaf = nullptr;
448 CodeNode* pre_code_node = root_.get();
449 for (int i = 0; i < fn->call_stack()->traces().size(); ++i) {
450 const string& trace =
451 GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
452 pre_code_node = pre_code_node->AddChildren(
453 trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
454 if (i == fn->call_stack()->traces().size() - 1) {
455 leaf = pre_code_node;
456 }
457 }
458 for (TFGraphNode* gn : it.second) {
459 leaf->node->AddGraphNode(gn);
460 }
461 }
462 if (unaccounted_nodes > 0) {
463 absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
464 unaccounted_nodes);
465 }
466 }
467
// Accounts, filters, sorts and formats the code tree per 'opts', returning
// the displayed root. In pprof mode (kOutput[3]) it additionally writes a
// compressed pprof profile to the path in opts.output_options.
const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
  // pprof output supports exactly one selected attribute, and only the
  // attributes whitelisted below.
  if (opts.output_type == kOutput[3]) {
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  // Recompute accounted stats bottom-up, then aggregate them into the root.
  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  // Narrow the displayed roots unless start_name_regexes is the
  // match-everything default {".*"}.
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  // PrintScope is seeded with the root itself, which it always returns.
  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    // pprof mode: walk the displayed tree to build the location/sample
    // tables, then write the gzip-compressed profile proto.
    std::vector<uint64> call_ids;
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}
525
// Recursively renders 'nodes' (the displayed children of 'root') into
// 'display_str' and 'proto'. In pprof mode (kOutput[3]) it also maintains
// 'call_ids' as a stack of location ids from the top of the displayed tree
// down to the current node, registering a sample whenever a traced node
// with no displayed children is reached.
void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
                    const Options& opts, string* display_str,
                    MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
  // 'root' is a displayed leaf: emit a sample for the current call stack.
  if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    pprof_profile_->AddSample(root, call_ids);
  }

  for (CodeNode* node : nodes) {
    // Push this call site before descending, pop after — call_ids always
    // mirrors the path from the top to 'node'.
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      uint64 loc_id = pprof_profile_->AddLocation(node, root);
      call_ids->push_back(loc_id);
    }
    display_str->append(node->formatted_str);
    MultiGraphNodeProto* child = proto->add_children();
    child->MergeFrom(node->proto());
    Format(node, node->show_children, opts, display_str, child, call_ids);
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      call_ids->pop_back();
    }
  }
}
547
SearchRoot(std::vector<CodeNode * > roots,const std::vector<string> & regexes)548 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
549 const std::vector<string>& regexes) {
550 std::vector<CodeNode*> res;
551 if (roots.empty()) {
552 return res;
553 }
554 for (CodeNode* root : roots) {
555 bool match_start_node = false;
556 for (const string& regex : regexes) {
557 if (RE2::FullMatch(root->name(), regex)) {
558 res.push_back(root);
559 match_start_node = true;
560 break;
561 }
562 }
563 if (match_start_node) {
564 // Found a start node at this branch, no need to continue.
565 continue;
566 }
567 std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
568 res.insert(res.end(), nroots.begin(), nroots.end());
569 }
570 return res;
571 }
572
// Filters and formats the sub-trees in 'roots'. Returns the nodes to show
// at this level: a node passing ShouldShow keeps its (sorted) shown
// descendants as show_children and gets its formatted_str; a hidden node is
// replaced in the result by its shown descendants.
std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
                                          const Options& opts, int depth,
                                          int last_ident) {
  std::vector<CodeNode*> show_nodes;

  for (CodeNode* node : roots) {
    if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
      continue;
    }
    int ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    // Children of a shown node are indented two extra spaces.
    if (show) ident += 2;

    // Recurse first; a hidden node's shown descendants bubble up.
    std::vector<CodeNode*> show_cnodes =
        PrintScope(node->show_children, opts, depth + 1, ident);
    if (show) {
      node->show_children.clear();

      show_cnodes = SortNodes(show_cnodes, opts);
      for (CodeNode* sc : show_cnodes) {
        node->show_children.push_back(sc);
      }

      node->formatted_str = FormatNode(node, opts, last_ident);

      // kShown[4] is tensor-value output, which the code view cannot show.
      if (opts.select.find(kShown[4]) != opts.select.end()) {
        absl::FPrintF(stderr, "code view has no tensor value to show\n");
      }
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}
609
// Recomputes, bottom-up, which nodes are accounted under 'opts'. A node is
// kept if it is accounted itself or has an accounted descendant; a kept
// node's total stats are rebuilt from its self stats plus its kept
// children, and those children become its show_children.
std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
                                       const Options& opts) {
  std::vector<CodeNode*> act_nodes;

  for (CodeNode* node : roots) {
    node->ResetTotalStats();
    // Account the children first so their totals are ready to aggregate.
    std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
    node->account = ReAccount(node, opts);
    if (node->account || !act_cnodes.empty()) {
      node->show_children.clear();
      // Reset again before aggregation so totals start from self stats.
      node->ResetTotalStats();
      node->AddSelfToTotalStats();
      for (CodeNode* c : act_cnodes) {
        node->AggregateTotalStats(c);
        node->show_children.push_back(c);
      }
      act_nodes.push_back(node);
    }
  }
  return act_nodes;
}
631
FormatNodeMemory(CodeNode * node,int64 bytes,int64 total_bytes) const632 string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
633 int64 total_bytes) const {
634 string memory = FormatMemory(total_bytes);
635 if (node->account) {
636 memory = FormatMemory(bytes) + "/" + memory;
637 } else {
638 memory = "--/" + memory;
639 }
640 return memory;
641 }
642
FormatNode(CodeNode * node,const Options & opts,int64 indent) const643 string TFCode::FormatNode(CodeNode* node, const Options& opts,
644 int64 indent) const {
645 std::vector<string> attrs;
646 if (opts.select.find(kShown[0]) != opts.select.end()) {
647 attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
648 node->proto().total_requested_bytes()));
649 }
650 if (opts.select.find(kShown[11]) != opts.select.end()) {
651 attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
652 node->proto().total_peak_bytes()));
653 }
654 if (opts.select.find(kShown[12]) != opts.select.end()) {
655 attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
656 node->proto().total_residual_bytes()));
657 }
658 if (opts.select.find(kShown[13]) != opts.select.end()) {
659 attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
660 node->proto().total_output_bytes()));
661 }
662
663 std::vector<string> time_attrs = FormatTimes(node, opts);
664 attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());
665
666 if (opts.select.find(kShown[2]) != opts.select.end()) {
667 string params = FormatNumber(node->proto().total_parameters()) + " params";
668 if (node->account) {
669 params = FormatNumber(node->proto().parameters()) + "/" + params;
670 } else {
671 params = "--/" + params;
672 }
673 attrs.push_back(params);
674 }
675
676 if (opts.select.find(kShown[3]) != opts.select.end()) {
677 string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
678 if (node->account) {
679 fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
680 } else {
681 fops = "--/" + fops;
682 }
683 attrs.push_back(fops);
684 }
685
686 if (opts.select.find(kShown[5]) != opts.select.end() &&
687 !node->node->devices().empty()) {
688 attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
689 }
690 if (opts.select.find(kShown[6]) != opts.select.end()) {
691 std::set<string> op_types = node->node->op_types();
692 attrs.push_back(absl::StrJoin(op_types, "|"));
693 }
694 if (opts.select.find(kShown[7]) != opts.select.end()) {
695 // TODO(xpan): Make op count available in code view?
696 attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
697 }
698 if (opts.select.find(kShown[8]) != opts.select.end()) {
699 attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
700 }
701
702 return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
703 absl::StrJoin(attrs, ", "));
704 }
705 } // namespace tfprof
706 } // namespace tensorflow
707