/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/internal/tfprof_code.h"

#include <stdio.h>

#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/profiler/internal/tfprof_constants.h"

namespace tensorflow {
namespace tfprof {
namespace {

const char* const kGradientSuffix = " (gradient)";

// Converts a Trace proto into a short readable string.
std::string GetTraceString(const CallStack::Trace& trace) {
  std::string ntrace =
      absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
  if (trace.function().length() < 20) {
    absl::StrAppend(&ntrace, ":", trace.function());
  } else {
    absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
  }
  return ntrace;
}

bool IsGradNode(const string& name, string* forward_name) {
  // Given a forward operation with name op, its gradient op has the following
  // name: ...gradients/op_grad/...
  // TODO(xpan): This is hacky.
  auto grad_prefix = name.find("gradients/");
  auto grad_suffix = name.find("_grad/");
  if (grad_prefix == name.npos || grad_suffix == name.npos) {
    return false;
  }
  auto start = grad_prefix + string("gradients/").length();
  auto len = grad_suffix - start;
  if (len <= 0) {
    return false;
  }
  *forward_name = name.substr(start, len);
  return true;
}

// StringTable maps each string to an id.
class StringTable {
 public:
  StringTable() {
    // Pprof requires first entry in string_table to be ''.
    string_id_[""] = 0;
    all_strings_.push_back("");
  }

  // Returns the index of a string. If not found, inserts the string and
  // returns the inserted index.
  uint64 GetIndex(const string& str) {
    auto idx = string_id_.find(str);
    if (idx != string_id_.end()) {
      return idx->second;
    }
    all_strings_.push_back(str);
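    // The new id equals the current map size, so ids stay dense and match the
    // string's position in all_strings_.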
    return string_id_.insert(std::pair<string, int64>(str, string_id_.size()))
        .first->second;
  }

  const std::vector<string>& strings() const { return all_strings_; }

 private:
  std::map<string, uint64> string_id_;
  std::vector<string> all_strings_;
};

// FunctionTable maps each function to an id.
class FunctionTable {
 public:
  explicit FunctionTable(StringTable* string_table)
      : string_table_(string_table) {}

  // Returns the index of a function. If not found, adds a function proto
  // and returns the function index.
  uint64 GetIndex(const string& file_path, const string& func_name,
                  uint64 func_start_line) {
    auto key = std::tuple<string, string, uint64>(file_path, func_name,
                                                  func_start_line);
    auto idx = function_table_.find(key);
    if (idx != function_table_.end()) {
      return idx->second.id();
    }
    pprof::Function* func_pb = &function_table_[key];
    // function index should start from 1.
    func_pb->set_id(function_table_.size());

    string file_base(io::Basename(file_path));
    file_base = file_base.substr(0, file_base.find_last_of("."));
    func_pb->set_name(
        string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
    func_pb->set_filename(string_table_->GetIndex(file_path));
    func_pb->set_start_line(func_start_line);
    return func_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Function>&
  functions() const {
    return function_table_;
  }

 private:
  StringTable* string_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
};

// LocationTable maps each function call to an id.
class LocationTable {
 public:
  explicit LocationTable(FunctionTable* function_table)
      : function_table_(function_table) {}

  // Returns the index of a function call location. If not found, adds a
  // location proto and returns the location index.
  uint64 GetIndex(const string& file_path, uint64 line_number,
                  const string& called_function_name,
                  const string& called_file_path,
                  uint64 called_func_start_line) {
    auto key = std::tuple<string, string, uint64>(
        file_path, called_function_name, line_number);

    auto idx = location_table_.find(key);
    if (idx != location_table_.end()) {
      return idx->second.id();
    }
    pprof::Location* location_pb = &location_table_[key];
    location_pb->set_id(location_table_.size());
    pprof::Line* line_pb = location_pb->add_line();
    line_pb->set_function_id(function_table_->GetIndex(
        called_file_path, called_function_name, called_func_start_line));
    line_pb->set_line(line_number);
    return location_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Location>&
  locations() const {
    return location_table_;
  }

 private:
  FunctionTable* function_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
};

// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace. It includes all graph nodes
  // created by it. 'location_ids' contains the call stack, from callee to
  // caller. This method adds the statistics of graph nodes created by the
  // python call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // The displayed leaf might not be a true leaf. Retrieve the true leaves
    // for stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    for (const CodeNode* cn : all_leaf) {
      for (auto gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        sample_pb->mutable_value()->Add(1);
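        // Each sample carries two values: a count of 1 and the metric for the
        // single attribute chosen via -select (pprof output allows only one).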
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};

class PprofProfileImpl : public PprofProfile {
 public:
  explicit PprofProfileImpl(const Options* opts)
      : opts_(opts),
        func_table_(new FunctionTable(&string_table_)),
        loc_table_(new LocationTable(func_table_.get())),
        samples_(new Samples(&string_table_, opts)) {}

  uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
    const string& file_path = caller->file();
    uint64 lineno = caller->lineno();
    const string& callee_file_path = callee->file();
    const string& callee_function = callee->function();
    uint64 callee_func_start_line = callee->func_start_line();

    return loc_table_->GetIndex(file_path, lineno, callee_function,
                                callee_file_path, callee_func_start_line);
  }

  void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
    std::vector<uint64> reversed_call_ids;
    std::reverse_copy(call_ids->begin(), call_ids->end(),
                      std::back_inserter(reversed_call_ids));
    samples_->Add(leaf, reversed_call_ids);
  }

  Status WritePprofProfile(const string& filename) override {
    pprof::Profile profile_pb;
    Build(&profile_pb);

    std::unique_ptr<WritableFile> file;
    Status s = Env::Default()->NewWritableFile(filename, &file);
    if (!s.ok()) return s;

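    // Serialize the profile proto and write it out gzip-compressed; pprof can
    // read gzip-compressed profiles directly.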
    int32 buf_size = 1024 * 1024;
    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
    s = zlib_output_buffer->Init();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Close();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    absl::FPrintF(stdout,
                  "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
                  filename);
    delete zlib_output_buffer;
    return s;
  }

 private:
  void Build(pprof::Profile* profile_pb) {
    string sample_type_description = "count";
    auto sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    sample_type->set_unit(string_table_.GetIndex("count"));

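    // The second sample value is the selected metric; its unit and description
    // depend on the -select attribute.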
    string type = *opts_->select.begin();
    sample_type_description = type;
    sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
      sample_type->set_unit(string_table_.GetIndex("microseconds"));
      if (type == kShown[1]) {
        profile_pb->mutable_comment()->Add(string_table_.GetIndex(
            "Sum of accelerator execution time and cpu execution time."));
      } else if (type == kShown[9]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("Accelerator execution time."));
      } else if (type == kShown[10]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("CPU execution time."));
      }
    } else if (type == kShown[0]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation total memory requests, "
                                 "excluding deallocations."));
    } else if (type == kShown[11]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation peak memory usage."));
    } else if (type == kShown[12]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Sum of operation allocated memory after finish."));
    } else if (type == kShown[13]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation output size."));
    } else if (type == kShown[2]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Model parameters."));
    } else if (type == kShown[3]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Model float operations (Only available if defined)."));
    } else {
      absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
    }

    for (const string& str : string_table_.strings()) {
      *profile_pb->mutable_string_table()->Add() = str;
    }
    for (const auto& sample_it : samples_->samples()) {
      // TODO(xpan): Consider swap.
      profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
    }
    for (const auto& function_it : func_table_->functions()) {
      profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
    }
    for (const auto& location_it : loc_table_->locations()) {
      profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
    }
  }

  const Options* opts_;
  StringTable string_table_;
  std::unique_ptr<FunctionTable> func_table_;
  std::unique_ptr<LocationTable> loc_table_;
  std::unique_ptr<Samples> samples_;
};
}  // namespace

void TFCode::AddNode(TFGraphNode* node) {
  if (!node->call_stack() || node->call_stack()->traces().empty()) {
    return;
  }
  // We infer the forward operation name from the gradient op name, so we can
  // map gradient op traces to forward op traces. E.g. the gradient node of
  // 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad'.
  string forward_name;
  if (IsGradNode(node->name(), &forward_name)) {
    auto grad_nodes_it = grad_nodes_.find(forward_name);
    if (grad_nodes_it != grad_nodes_.end()) {
      grad_nodes_it->second.push_back(node);
    } else {
      grad_nodes_.insert(
          std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
    }
    return;
  } else {
    forward_nodes_[node->name()] = node;
  }

  if (!root_) {
    graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
    root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
  }

  CodeNode* pre_code_node = root_.get();
  // TODO(xpan): Consider releasing CodeDef after TFCode is built. It
  // takes a lot of memory.
  std::set<string> traces;
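  // Build a chain of CodeNodes, one per stack frame, each a child of the
  // previous frame's node; the graph node itself is attached to the last
  // (innermost) frame.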
  for (int i = 0; i < node->call_stack()->traces().size(); ++i) {
    // Unlike op name, which is globally unique, trace name is only unique
    // w.r.t. its parent.
    const string& trace = GetTraceString(node->call_stack()->traces().at(i));
    traces.insert(trace);
    pre_code_node = pre_code_node->AddChildren(
        trace, &node->call_stack()->traces().at(i), "");
    if (i == node->call_stack()->traces().size() - 1) {
      pre_code_node->node->AddGraphNode(node);
    }
  }
}

void TFCode::Build() {
  int64 unaccounted_nodes = 0;
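  // For each gradient op, replay its forward op's call stack with the
  // " (gradient)" suffix appended, so gradient nodes show up as a parallel
  // subtree next to the forward trace.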
  for (auto it : grad_nodes_) {
    const string& forward_name = it.first;
    auto forward_it = forward_nodes_.find(forward_name);
    if (forward_it == forward_nodes_.end()) {
      unaccounted_nodes += 1;
      continue;
    }
    TFGraphNode* fn = forward_it->second;
    CodeNode* leaf = nullptr;
    CodeNode* pre_code_node = root_.get();
    for (int i = 0; i < fn->call_stack()->traces().size(); ++i) {
      const string& trace =
          GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
      pre_code_node = pre_code_node->AddChildren(
          trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
      if (i == fn->call_stack()->traces().size() - 1) {
        leaf = pre_code_node;
      }
    }
    for (TFGraphNode* gn : it.second) {
      leaf->node->AddGraphNode(gn);
    }
  }
  if (unaccounted_nodes > 0) {
    absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
                  unaccounted_nodes);
  }
}

const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
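  // kOutput[3] is the pprof output type; it supports exactly one selected
  // attribute per generated profile.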
  if (opts.output_type == kOutput[3]) {
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    std::vector<uint64> call_ids;
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}

void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
                    const Options& opts, string* display_str,
                    MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
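  // For pprof output, call_ids is a stack of pprof location ids along the
  // current call path; a sample is emitted whenever a displayed leaf with a
  // trace is reached.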
  if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    pprof_profile_->AddSample(root, call_ids);
  }

  for (CodeNode* node : nodes) {
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      uint64 loc_id = pprof_profile_->AddLocation(node, root);
      call_ids->push_back(loc_id);
    }
    display_str->append(node->formatted_str);
    MultiGraphNodeProto* child = proto->add_children();
    child->MergeFrom(node->proto());
    Format(node, node->show_children, opts, display_str, child, call_ids);
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      call_ids->pop_back();
    }
  }
}

std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
                                          const std::vector<string>& regexes) {
  std::vector<CodeNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (CodeNode* root : roots) {
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      // Found a start node at this branch, no need to continue.
      continue;
    }
    std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
    res.insert(res.end(), nroots.begin(), nroots.end());
  }
  return res;
}

std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
                                          const Options& opts, int depth,
                                          int last_ident) {
  std::vector<CodeNode*> show_nodes;

  for (CodeNode* node : roots) {
    if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
      continue;
    }
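    // The indent grows by 2 for each displayed ancestor; nodes that are not
    // shown pass their indent through to their children.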
    int ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) ident += 2;

    std::vector<CodeNode*> show_cnodes =
        PrintScope(node->show_children, opts, depth + 1, ident);
    if (show) {
      node->show_children.clear();

      show_cnodes = SortNodes(show_cnodes, opts);
      for (CodeNode* sc : show_cnodes) {
        node->show_children.push_back(sc);
      }

      node->formatted_str = FormatNode(node, opts, last_ident);

      if (opts.select.find(kShown[4]) != opts.select.end()) {
        absl::FPrintF(stderr, "code view has no tensor value to show\n");
      }
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
                                       const Options& opts) {
  std::vector<CodeNode*> act_nodes;

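  // A node is kept if it is accounted itself or has any accounted descendant;
  // its total stats are then rebuilt from the kept subtree.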
  for (CodeNode* node : roots) {
    node->ResetTotalStats();
    std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
    node->account = ReAccount(node, opts);
    if (node->account || !act_cnodes.empty()) {
      node->show_children.clear();
      node->ResetTotalStats();
      node->AddSelfToTotalStats();
      for (CodeNode* c : act_cnodes) {
        node->AggregateTotalStats(c);
        node->show_children.push_back(c);
      }
      act_nodes.push_back(node);
    }
  }
  return act_nodes;
}

string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
                                int64 total_bytes) const {
  string memory = FormatMemory(total_bytes);
  if (node->account) {
    memory = FormatMemory(bytes) + "/" + memory;
  } else {
    memory = "--/" + memory;
  }
  return memory;
}

string TFCode::FormatNode(CodeNode* node, const Options& opts,
                          int64 indent) const {
  std::vector<string> attrs;
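  // Each attribute is rendered as "accounted/total"; "--" marks nodes whose
  // own stats are not accounted under the current options.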
  if (opts.select.find(kShown[0]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
                                     node->proto().total_requested_bytes()));
  }
  if (opts.select.find(kShown[11]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
                                     node->proto().total_peak_bytes()));
  }
  if (opts.select.find(kShown[12]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
                                     node->proto().total_residual_bytes()));
  }
  if (opts.select.find(kShown[13]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
                                     node->proto().total_output_bytes()));
  }

  std::vector<string> time_attrs = FormatTimes(node, opts);
  attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());

  if (opts.select.find(kShown[2]) != opts.select.end()) {
    string params = FormatNumber(node->proto().total_parameters()) + " params";
    if (node->account) {
      params = FormatNumber(node->proto().parameters()) + "/" + params;
    } else {
      params = "--/" + params;
    }
    attrs.push_back(params);
  }

  if (opts.select.find(kShown[3]) != opts.select.end()) {
    string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
    if (node->account) {
      fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
    } else {
      fops = "--/" + fops;
    }
    attrs.push_back(fops);
  }

  if (opts.select.find(kShown[5]) != opts.select.end() &&
      !node->node->devices().empty()) {
    attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
  }
  if (opts.select.find(kShown[6]) != opts.select.end()) {
    std::set<string> op_types = node->node->op_types();
    attrs.push_back(absl::StrJoin(op_types, "|"));
  }
  if (opts.select.find(kShown[7]) != opts.select.end()) {
    // TODO(xpan): Make op count available in code view?
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
  }
  if (opts.select.find(kShown[8]) != opts.select.end()) {
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
  }

  return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
                         absl::StrJoin(attrs, ", "));
}
}  // namespace tfprof
}  // namespace tensorflow