• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdio.h>
17 #include <stdlib.h>
18 
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_split.h"
27 #include "linenoise.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/init_main.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
33 #include "tensorflow/core/profiler/internal/tfprof_stats.h"
34 #include "tensorflow/core/profiler/internal/tfprof_utils.h"
35 #include "tensorflow/core/profiler/tfprof_log.pb.h"
36 #include "tensorflow/core/profiler/tfprof_options.h"
37 #include "tensorflow/core/util/command_line_flags.h"
38 
39 namespace tensorflow {
40 namespace tfprof {
completion(const char * buf,linenoiseCompletions * lc)41 void completion(const char* buf, linenoiseCompletions* lc) {
42   string buf_str = buf;
43   if (buf_str.find(' ') == buf_str.npos) {
44     for (const char* opt : kCmds) {
45       if (string(opt).find(buf_str) == 0) {
46         linenoiseAddCompletion(lc, opt);
47       }
48     }
49     return;
50   }
51 
52   string prefix;
53   int last_dash = buf_str.find_last_of(' ');
54   if (last_dash != string::npos) {
55     prefix = buf_str.substr(0, last_dash + 1);
56     buf_str = buf_str.substr(last_dash + 1, kint32max);
57   }
58   for (const char* opt : kOptions) {
59     if (string(opt).find(buf_str) == 0) {
60       linenoiseAddCompletion(lc, (prefix + opt).c_str());
61     }
62   }
63 }
64 
Run(int argc,char ** argv)65 int Run(int argc, char** argv) {
66   string FLAGS_profile_path = "";
67   string FLAGS_graph_path = "";
68   string FLAGS_run_meta_path = "";
69   string FLAGS_op_log_path = "";
70   string FLAGS_checkpoint_path = "";
71   int32_t FLAGS_max_depth = 10;
72   int64_t FLAGS_min_bytes = 0;
73   int64_t FLAGS_min_peak_bytes = 0;
74   int64_t FLAGS_min_residual_bytes = 0;
75   int64_t FLAGS_min_output_bytes = 0;
76   int64_t FLAGS_min_micros = 0;
77   int64_t FLAGS_min_accelerator_micros = 0;
78   int64_t FLAGS_min_cpu_micros = 0;
79   int64_t FLAGS_min_params = 0;
80   int64_t FLAGS_min_float_ops = 0;
81   int64_t FLAGS_min_occurrence = 0;
82   int64_t FLAGS_step = -1;
83   string FLAGS_order_by = "name";
84   string FLAGS_account_type_regexes = ".*";
85   string FLAGS_start_name_regexes = ".*";
86   string FLAGS_trim_name_regexes = "";
87   string FLAGS_show_name_regexes = ".*";
88   string FLAGS_hide_name_regexes;
89   bool FLAGS_account_displayed_op_only = false;
90   string FLAGS_select = "micros";
91   string FLAGS_output = "";
92   for (int i = 0; i < argc; i++) {
93     absl::FPrintF(stderr, "%s\n", argv[i]);
94   }
95 
96   std::vector<Flag> flag_list = {
97       Flag("profile_path", &FLAGS_profile_path, "Profile binary file name."),
98       Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"),
99       Flag("run_meta_path", &FLAGS_run_meta_path,
100            "Comma-separated list of RunMetadata proto binary "
101            "files. Each file is given step number 0,1,2,etc"),
102       Flag("op_log_path", &FLAGS_op_log_path,
103            "tensorflow::tfprof::OpLogProto proto binary file name"),
104       Flag("checkpoint_path", &FLAGS_checkpoint_path,
105            "TensorFlow Checkpoint file name"),
106       Flag("max_depth", &FLAGS_max_depth, "max depth"),
107       Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
108       Flag("min_peak_bytes", &FLAGS_min_peak_bytes, "min_peak_bytes"),
109       Flag("min_residual_bytes", &FLAGS_min_residual_bytes,
110            "min_residual_bytes"),
111       Flag("min_output_bytes", &FLAGS_min_output_bytes, "min_output_bytes"),
112       Flag("min_micros", &FLAGS_min_micros, "min micros"),
113       Flag("min_accelerator_micros", &FLAGS_min_accelerator_micros,
114            "min accelerator_micros"),
115       Flag("min_cpu_micros", &FLAGS_min_cpu_micros, "min_cpu_micros"),
116       Flag("min_params", &FLAGS_min_params, "min params"),
117       Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
118       Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"),
119       Flag("step", &FLAGS_step,
120            "The stats of which step to use. By default average"),
121       Flag("order_by", &FLAGS_order_by, "order by"),
122       Flag("account_type_regexes", &FLAGS_start_name_regexes,
123            "start name regexes"),
124       Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"),
125       Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"),
126       Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"),
127       Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only,
128            "account displayed op only"),
129       Flag("select", &FLAGS_select, "select"),
130       Flag("output", &FLAGS_output, "output"),
131   };
132   string usage = Flags::Usage(argv[0], flag_list);
133   bool parse_ok = Flags::Parse(&argc, argv, flag_list);
134   if (!parse_ok) {
135     absl::PrintF("%s", usage);
136     return (2);
137   }
138   port::InitMain(argv[0], &argc, &argv);
139 
140   if (!FLAGS_profile_path.empty() &&
141       (!FLAGS_graph_path.empty() || !FLAGS_run_meta_path.empty())) {
142     absl::FPrintF(stderr,
143                   "--profile_path is set, do not set --graph_path or "
144                   "--run_meta_path\n");
145     return 1;
146   }
147 
148   std::vector<string> account_type_regexes =
149       absl::StrSplit(FLAGS_account_type_regexes, ',', absl::SkipEmpty());
150   std::vector<string> start_name_regexes =
151       absl::StrSplit(FLAGS_start_name_regexes, ',', absl::SkipEmpty());
152   std::vector<string> trim_name_regexes =
153       absl::StrSplit(FLAGS_trim_name_regexes, ',', absl::SkipEmpty());
154   std::vector<string> show_name_regexes =
155       absl::StrSplit(FLAGS_show_name_regexes, ',', absl::SkipEmpty());
156   std::vector<string> hide_name_regexes =
157       absl::StrSplit(FLAGS_hide_name_regexes, ',', absl::SkipEmpty());
158   std::vector<string> select =
159       absl::StrSplit(FLAGS_select, ',', absl::SkipEmpty());
160 
161   string output_type;
162   std::map<string, string> output_options;
163   Status s = ParseOutput(FLAGS_output, &output_type, &output_options);
164   CHECK(s.ok()) << s.ToString();
165 
166   string cmd = "";
167   if (argc == 1 && FLAGS_graph_path.empty() && FLAGS_profile_path.empty() &&
168       FLAGS_run_meta_path.empty()) {
169     PrintHelp();
170     return 0;
171   } else if (argc > 1) {
172     if (string(argv[1]) == kCmds[6]) {
173       PrintHelp();
174       return 0;
175     }
176     if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
177         string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
178         string(argv[1]) == kCmds[4]) {
179       cmd = argv[1];
180     }
181   }
182 
183   absl::PrintF("Reading Files...\n");
184   std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
185   TF_Status* status = TF_NewStatus();
186   if (!FLAGS_checkpoint_path.empty()) {
187     ckpt_reader.reset(
188         new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status));
189     if (TF_GetCode(status) != TF_OK) {
190       absl::FPrintF(stderr, "%s\n", TF_Message(status));
191       TF_DeleteStatus(status);
192       return 1;
193     }
194     TF_DeleteStatus(status);
195   }
196 
197   std::unique_ptr<TFStats> tf_stat;
198   if (!FLAGS_profile_path.empty()) {
199     tf_stat.reset(new TFStats(FLAGS_profile_path, std::move(ckpt_reader)));
200   } else {
201     absl::PrintF(
202         "Try to use a single --profile_path instead of "
203         "graph_path,op_log_path,run_meta_path\n");
204     std::unique_ptr<GraphDef> graph(new GraphDef());
205     if (!FLAGS_graph_path.empty()) {
206       s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false);
207       if (!s.ok()) {
208         absl::FPrintF(stderr, "Failed to read graph_path: %s\n", s.ToString());
209         return 1;
210       }
211     }
212 
213     std::unique_ptr<OpLogProto> op_log(new OpLogProto());
214     if (!FLAGS_op_log_path.empty()) {
215       string op_log_str;
216       s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str);
217       if (!s.ok()) {
218         absl::FPrintF(stderr, "Failed to read op_log_path: %s\n", s.ToString());
219         return 1;
220       }
221       if (!ParseProtoUnlimited(op_log.get(), op_log_str)) {
222         absl::FPrintF(stderr, "Failed to parse op_log_path\n");
223         return 1;
224       }
225     }
226     tf_stat.reset(new TFStats(std::move(graph), nullptr, std::move(op_log),
227                               std::move(ckpt_reader)));
228 
229     std::vector<string> run_meta_files =
230         absl::StrSplit(FLAGS_run_meta_path, ',', absl::SkipEmpty());
231     for (int i = 0; i < run_meta_files.size(); ++i) {
232       std::unique_ptr<RunMetadata> run_meta(new RunMetadata());
233       s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(),
234                         true);
235       if (!s.ok()) {
236         absl::FPrintF(stderr, "Failed to read run_meta_path %s. Status: %s\n",
237                       run_meta_files[i], s.ToString());
238         return 1;
239       }
240       tf_stat->AddRunMeta(i, std::move(run_meta));
241       absl::FPrintF(stdout, "run graph coverage: %.2f\n",
242                     tf_stat->run_coverage());
243     }
244   }
245 
246   if (cmd == kCmds[4]) {
247     tf_stat->BuildAllViews();
248     Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
249     return 0;
250   }
251 
252   Options opts(
253       FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_peak_bytes,
254       FLAGS_min_residual_bytes, FLAGS_min_output_bytes, FLAGS_min_micros,
255       FLAGS_min_accelerator_micros, FLAGS_min_cpu_micros, FLAGS_min_params,
256       FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
257       account_type_regexes, start_name_regexes, trim_name_regexes,
258       show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
259       select, output_type, output_options);
260 
261   if (cmd == kCmds[2] || cmd == kCmds[3]) {
262     tf_stat->BuildView(cmd);
263     tf_stat->ShowMultiGraphNode(cmd, opts);
264     return 0;
265   } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
266     tf_stat->BuildView(cmd);
267     tf_stat->ShowGraphNode(cmd, opts);
268     return 0;
269   }
270 
271   linenoiseSetCompletionCallback(completion);
272   linenoiseHistoryLoad(".tfprof_history.txt");
273 
274   bool looped = false;
275   while (true) {
276     char* line = linenoise("tfprof> ");
277     if (line == nullptr) {
278       if (!looped) {
279         absl::FPrintF(stderr,
280                       "Cannot start interactive shell, "
281                       "use 'bazel-bin' instead of 'bazel run'.\n");
282       }
283       break;
284     }
285     looped = true;
286     string line_s = line;
287     free(line);
288 
289     if (line_s.empty()) {
290       absl::PrintF("%s", opts.ToString());
291       continue;
292     }
293     linenoiseHistoryAdd(line_s.c_str());
294     linenoiseHistorySave(".tfprof_history.txt");
295 
296     Options new_opts = opts;
297     Status s = ParseCmdLine(line_s, &cmd, &new_opts);
298     if (!s.ok()) {
299       absl::FPrintF(stderr, "E: %s\n", s.ToString());
300       continue;
301     }
302     if (cmd == kCmds[5]) {
303       opts = new_opts;
304     } else if (cmd == kCmds[6]) {
305       PrintHelp();
306     } else if (cmd == kCmds[2] || cmd == kCmds[3]) {
307       tf_stat->BuildView(cmd);
308       tf_stat->ShowMultiGraphNode(cmd, new_opts);
309     } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
310       tf_stat->BuildView(cmd);
311       tf_stat->ShowGraphNode(cmd, new_opts);
312     } else if (cmd == kCmds[4]) {
313       tf_stat->BuildAllViews();
314       Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
315     }
316   }
317   return 0;
318 }
319 }  // namespace tfprof
320 }  // namespace tensorflow
321 
main(int argc,char ** argv)322 int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); }
323