1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdio.h>
17 #include <stdlib.h>
18
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_split.h"
27 #include "linenoise.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/init_main.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
33 #include "tensorflow/core/profiler/internal/tfprof_stats.h"
34 #include "tensorflow/core/profiler/internal/tfprof_utils.h"
35 #include "tensorflow/core/profiler/tfprof_log.pb.h"
36 #include "tensorflow/core/profiler/tfprof_options.h"
37 #include "tensorflow/core/util/command_line_flags.h"
38
39 namespace tensorflow {
40 namespace tfprof {
completion(const char * buf,linenoiseCompletions * lc)41 void completion(const char* buf, linenoiseCompletions* lc) {
42 string buf_str = buf;
43 if (buf_str.find(' ') == buf_str.npos) {
44 for (const char* opt : kCmds) {
45 if (string(opt).find(buf_str) == 0) {
46 linenoiseAddCompletion(lc, opt);
47 }
48 }
49 return;
50 }
51
52 string prefix;
53 int last_dash = buf_str.find_last_of(' ');
54 if (last_dash != string::npos) {
55 prefix = buf_str.substr(0, last_dash + 1);
56 buf_str = buf_str.substr(last_dash + 1, kint32max);
57 }
58 for (const char* opt : kOptions) {
59 if (string(opt).find(buf_str) == 0) {
60 linenoiseAddCompletion(lc, (prefix + opt).c_str());
61 }
62 }
63 }
64
Run(int argc,char ** argv)65 int Run(int argc, char** argv) {
66 string FLAGS_profile_path = "";
67 string FLAGS_graph_path = "";
68 string FLAGS_run_meta_path = "";
69 string FLAGS_op_log_path = "";
70 string FLAGS_checkpoint_path = "";
71 int32_t FLAGS_max_depth = 10;
72 int64_t FLAGS_min_bytes = 0;
73 int64_t FLAGS_min_peak_bytes = 0;
74 int64_t FLAGS_min_residual_bytes = 0;
75 int64_t FLAGS_min_output_bytes = 0;
76 int64_t FLAGS_min_micros = 0;
77 int64_t FLAGS_min_accelerator_micros = 0;
78 int64_t FLAGS_min_cpu_micros = 0;
79 int64_t FLAGS_min_params = 0;
80 int64_t FLAGS_min_float_ops = 0;
81 int64_t FLAGS_min_occurrence = 0;
82 int64_t FLAGS_step = -1;
83 string FLAGS_order_by = "name";
84 string FLAGS_account_type_regexes = ".*";
85 string FLAGS_start_name_regexes = ".*";
86 string FLAGS_trim_name_regexes = "";
87 string FLAGS_show_name_regexes = ".*";
88 string FLAGS_hide_name_regexes;
89 bool FLAGS_account_displayed_op_only = false;
90 string FLAGS_select = "micros";
91 string FLAGS_output = "";
92 for (int i = 0; i < argc; i++) {
93 absl::FPrintF(stderr, "%s\n", argv[i]);
94 }
95
96 std::vector<Flag> flag_list = {
97 Flag("profile_path", &FLAGS_profile_path, "Profile binary file name."),
98 Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"),
99 Flag("run_meta_path", &FLAGS_run_meta_path,
100 "Comma-separated list of RunMetadata proto binary "
101 "files. Each file is given step number 0,1,2,etc"),
102 Flag("op_log_path", &FLAGS_op_log_path,
103 "tensorflow::tfprof::OpLogProto proto binary file name"),
104 Flag("checkpoint_path", &FLAGS_checkpoint_path,
105 "TensorFlow Checkpoint file name"),
106 Flag("max_depth", &FLAGS_max_depth, "max depth"),
107 Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
108 Flag("min_peak_bytes", &FLAGS_min_peak_bytes, "min_peak_bytes"),
109 Flag("min_residual_bytes", &FLAGS_min_residual_bytes,
110 "min_residual_bytes"),
111 Flag("min_output_bytes", &FLAGS_min_output_bytes, "min_output_bytes"),
112 Flag("min_micros", &FLAGS_min_micros, "min micros"),
113 Flag("min_accelerator_micros", &FLAGS_min_accelerator_micros,
114 "min accelerator_micros"),
115 Flag("min_cpu_micros", &FLAGS_min_cpu_micros, "min_cpu_micros"),
116 Flag("min_params", &FLAGS_min_params, "min params"),
117 Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
118 Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"),
119 Flag("step", &FLAGS_step,
120 "The stats of which step to use. By default average"),
121 Flag("order_by", &FLAGS_order_by, "order by"),
122 Flag("account_type_regexes", &FLAGS_start_name_regexes,
123 "start name regexes"),
124 Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"),
125 Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"),
126 Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"),
127 Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only,
128 "account displayed op only"),
129 Flag("select", &FLAGS_select, "select"),
130 Flag("output", &FLAGS_output, "output"),
131 };
132 string usage = Flags::Usage(argv[0], flag_list);
133 bool parse_ok = Flags::Parse(&argc, argv, flag_list);
134 if (!parse_ok) {
135 absl::PrintF("%s", usage);
136 return (2);
137 }
138 port::InitMain(argv[0], &argc, &argv);
139
140 if (!FLAGS_profile_path.empty() &&
141 (!FLAGS_graph_path.empty() || !FLAGS_run_meta_path.empty())) {
142 absl::FPrintF(stderr,
143 "--profile_path is set, do not set --graph_path or "
144 "--run_meta_path\n");
145 return 1;
146 }
147
148 std::vector<string> account_type_regexes =
149 absl::StrSplit(FLAGS_account_type_regexes, ',', absl::SkipEmpty());
150 std::vector<string> start_name_regexes =
151 absl::StrSplit(FLAGS_start_name_regexes, ',', absl::SkipEmpty());
152 std::vector<string> trim_name_regexes =
153 absl::StrSplit(FLAGS_trim_name_regexes, ',', absl::SkipEmpty());
154 std::vector<string> show_name_regexes =
155 absl::StrSplit(FLAGS_show_name_regexes, ',', absl::SkipEmpty());
156 std::vector<string> hide_name_regexes =
157 absl::StrSplit(FLAGS_hide_name_regexes, ',', absl::SkipEmpty());
158 std::vector<string> select =
159 absl::StrSplit(FLAGS_select, ',', absl::SkipEmpty());
160
161 string output_type;
162 std::map<string, string> output_options;
163 Status s = ParseOutput(FLAGS_output, &output_type, &output_options);
164 CHECK(s.ok()) << s.ToString();
165
166 string cmd = "";
167 if (argc == 1 && FLAGS_graph_path.empty() && FLAGS_profile_path.empty() &&
168 FLAGS_run_meta_path.empty()) {
169 PrintHelp();
170 return 0;
171 } else if (argc > 1) {
172 if (string(argv[1]) == kCmds[6]) {
173 PrintHelp();
174 return 0;
175 }
176 if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
177 string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
178 string(argv[1]) == kCmds[4]) {
179 cmd = argv[1];
180 }
181 }
182
183 absl::PrintF("Reading Files...\n");
184 std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
185 TF_Status* status = TF_NewStatus();
186 if (!FLAGS_checkpoint_path.empty()) {
187 ckpt_reader.reset(
188 new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status));
189 if (TF_GetCode(status) != TF_OK) {
190 absl::FPrintF(stderr, "%s\n", TF_Message(status));
191 TF_DeleteStatus(status);
192 return 1;
193 }
194 TF_DeleteStatus(status);
195 }
196
197 std::unique_ptr<TFStats> tf_stat;
198 if (!FLAGS_profile_path.empty()) {
199 tf_stat.reset(new TFStats(FLAGS_profile_path, std::move(ckpt_reader)));
200 } else {
201 absl::PrintF(
202 "Try to use a single --profile_path instead of "
203 "graph_path,op_log_path,run_meta_path\n");
204 std::unique_ptr<GraphDef> graph(new GraphDef());
205 if (!FLAGS_graph_path.empty()) {
206 s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false);
207 if (!s.ok()) {
208 absl::FPrintF(stderr, "Failed to read graph_path: %s\n", s.ToString());
209 return 1;
210 }
211 }
212
213 std::unique_ptr<OpLogProto> op_log(new OpLogProto());
214 if (!FLAGS_op_log_path.empty()) {
215 string op_log_str;
216 s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str);
217 if (!s.ok()) {
218 absl::FPrintF(stderr, "Failed to read op_log_path: %s\n", s.ToString());
219 return 1;
220 }
221 if (!ParseProtoUnlimited(op_log.get(), op_log_str)) {
222 absl::FPrintF(stderr, "Failed to parse op_log_path\n");
223 return 1;
224 }
225 }
226 tf_stat.reset(new TFStats(std::move(graph), nullptr, std::move(op_log),
227 std::move(ckpt_reader)));
228
229 std::vector<string> run_meta_files =
230 absl::StrSplit(FLAGS_run_meta_path, ',', absl::SkipEmpty());
231 for (int i = 0; i < run_meta_files.size(); ++i) {
232 std::unique_ptr<RunMetadata> run_meta(new RunMetadata());
233 s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(),
234 true);
235 if (!s.ok()) {
236 absl::FPrintF(stderr, "Failed to read run_meta_path %s. Status: %s\n",
237 run_meta_files[i], s.ToString());
238 return 1;
239 }
240 tf_stat->AddRunMeta(i, std::move(run_meta));
241 absl::FPrintF(stdout, "run graph coverage: %.2f\n",
242 tf_stat->run_coverage());
243 }
244 }
245
246 if (cmd == kCmds[4]) {
247 tf_stat->BuildAllViews();
248 Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
249 return 0;
250 }
251
252 Options opts(
253 FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_peak_bytes,
254 FLAGS_min_residual_bytes, FLAGS_min_output_bytes, FLAGS_min_micros,
255 FLAGS_min_accelerator_micros, FLAGS_min_cpu_micros, FLAGS_min_params,
256 FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
257 account_type_regexes, start_name_regexes, trim_name_regexes,
258 show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
259 select, output_type, output_options);
260
261 if (cmd == kCmds[2] || cmd == kCmds[3]) {
262 tf_stat->BuildView(cmd);
263 tf_stat->ShowMultiGraphNode(cmd, opts);
264 return 0;
265 } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
266 tf_stat->BuildView(cmd);
267 tf_stat->ShowGraphNode(cmd, opts);
268 return 0;
269 }
270
271 linenoiseSetCompletionCallback(completion);
272 linenoiseHistoryLoad(".tfprof_history.txt");
273
274 bool looped = false;
275 while (true) {
276 char* line = linenoise("tfprof> ");
277 if (line == nullptr) {
278 if (!looped) {
279 absl::FPrintF(stderr,
280 "Cannot start interactive shell, "
281 "use 'bazel-bin' instead of 'bazel run'.\n");
282 }
283 break;
284 }
285 looped = true;
286 string line_s = line;
287 free(line);
288
289 if (line_s.empty()) {
290 absl::PrintF("%s", opts.ToString());
291 continue;
292 }
293 linenoiseHistoryAdd(line_s.c_str());
294 linenoiseHistorySave(".tfprof_history.txt");
295
296 Options new_opts = opts;
297 Status s = ParseCmdLine(line_s, &cmd, &new_opts);
298 if (!s.ok()) {
299 absl::FPrintF(stderr, "E: %s\n", s.ToString());
300 continue;
301 }
302 if (cmd == kCmds[5]) {
303 opts = new_opts;
304 } else if (cmd == kCmds[6]) {
305 PrintHelp();
306 } else if (cmd == kCmds[2] || cmd == kCmds[3]) {
307 tf_stat->BuildView(cmd);
308 tf_stat->ShowMultiGraphNode(cmd, new_opts);
309 } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
310 tf_stat->BuildView(cmd);
311 tf_stat->ShowGraphNode(cmd, new_opts);
312 } else if (cmd == kCmds[4]) {
313 tf_stat->BuildAllViews();
314 Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
315 }
316 }
317 return 0;
318 }
319 } // namespace tfprof
320 } // namespace tensorflow
321
main(int argc,char ** argv)322 int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); }
323