1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdio.h>
17 #include <stdlib.h>
18
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_split.h"
27 #include "linenoise.h"
28 #include "tensorflow/c/c_api.h"
29 #include "tensorflow/c/checkpoint_reader.h"
30 #include "tensorflow/core/framework/graph.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/init_main.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
37 #include "tensorflow/core/profiler/internal/tfprof_stats.h"
38 #include "tensorflow/core/profiler/internal/tfprof_utils.h"
39 #include "tensorflow/core/profiler/tfprof_log.pb.h"
40 #include "tensorflow/core/profiler/tfprof_options.h"
41 #include "tensorflow/core/protobuf/config.pb.h"
42 #include "tensorflow/core/util/command_line_flags.h"
43
44 namespace tensorflow {
45 namespace tfprof {
completion(const char * buf,linenoiseCompletions * lc)46 void completion(const char* buf, linenoiseCompletions* lc) {
47 string buf_str = buf;
48 if (buf_str.find(" ") == buf_str.npos) {
49 for (const char* opt : kCmds) {
50 if (string(opt).find(buf_str) == 0) {
51 linenoiseAddCompletion(lc, opt);
52 }
53 }
54 return;
55 }
56
57 string prefix;
58 int last_dash = buf_str.find_last_of(' ');
59 if (last_dash != string::npos) {
60 prefix = buf_str.substr(0, last_dash + 1);
61 buf_str = buf_str.substr(last_dash + 1, kint32max);
62 }
63 for (const char* opt : kOptions) {
64 if (string(opt).find(buf_str) == 0) {
65 linenoiseAddCompletion(lc, (prefix + opt).c_str());
66 }
67 }
68 }
69
Run(int argc,char ** argv)70 int Run(int argc, char** argv) {
71 string FLAGS_profile_path = "";
72 string FLAGS_graph_path = "";
73 string FLAGS_run_meta_path = "";
74 string FLAGS_op_log_path = "";
75 string FLAGS_checkpoint_path = "";
76 int32 FLAGS_max_depth = 10;
77 int64 FLAGS_min_bytes = 0;
78 int64 FLAGS_min_peak_bytes = 0;
79 int64 FLAGS_min_residual_bytes = 0;
80 int64 FLAGS_min_output_bytes = 0;
81 int64 FLAGS_min_micros = 0;
82 int64 FLAGS_min_accelerator_micros = 0;
83 int64 FLAGS_min_cpu_micros = 0;
84 int64 FLAGS_min_params = 0;
85 int64 FLAGS_min_float_ops = 0;
86 int64 FLAGS_min_occurrence = 0;
87 int64 FLAGS_step = -1;
88 string FLAGS_order_by = "name";
89 string FLAGS_account_type_regexes = ".*";
90 string FLAGS_start_name_regexes = ".*";
91 string FLAGS_trim_name_regexes = "";
92 string FLAGS_show_name_regexes = ".*";
93 string FLAGS_hide_name_regexes;
94 bool FLAGS_account_displayed_op_only = false;
95 string FLAGS_select = "micros";
96 string FLAGS_output = "";
97 for (int i = 0; i < argc; i++) {
98 absl::FPrintF(stderr, "%s\n", argv[i]);
99 }
100
101 std::vector<Flag> flag_list = {
102 Flag("profile_path", &FLAGS_profile_path, "Profile binary file name."),
103 Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"),
104 Flag("run_meta_path", &FLAGS_run_meta_path,
105 "Comma-separated list of RunMetadata proto binary "
106 "files. Each file is given step number 0,1,2,etc"),
107 Flag("op_log_path", &FLAGS_op_log_path,
108 "tensorflow::tfprof::OpLogProto proto binary file name"),
109 Flag("checkpoint_path", &FLAGS_checkpoint_path,
110 "TensorFlow Checkpoint file name"),
111 Flag("max_depth", &FLAGS_max_depth, "max depth"),
112 Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
113 Flag("min_peak_bytes", &FLAGS_min_peak_bytes, "min_peak_bytes"),
114 Flag("min_residual_bytes", &FLAGS_min_residual_bytes,
115 "min_residual_bytes"),
116 Flag("min_output_bytes", &FLAGS_min_output_bytes, "min_output_bytes"),
117 Flag("min_micros", &FLAGS_min_micros, "min micros"),
118 Flag("min_accelerator_micros", &FLAGS_min_accelerator_micros,
119 "min accelerator_micros"),
120 Flag("min_cpu_micros", &FLAGS_min_cpu_micros, "min_cpu_micros"),
121 Flag("min_params", &FLAGS_min_params, "min params"),
122 Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
123 Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"),
124 Flag("step", &FLAGS_step,
125 "The stats of which step to use. By default average"),
126 Flag("order_by", &FLAGS_order_by, "order by"),
127 Flag("account_type_regexes", &FLAGS_start_name_regexes,
128 "start name regexes"),
129 Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"),
130 Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"),
131 Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"),
132 Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only,
133 "account displayed op only"),
134 Flag("select", &FLAGS_select, "select"),
135 Flag("output", &FLAGS_output, "output"),
136 };
137 string usage = Flags::Usage(argv[0], flag_list);
138 bool parse_ok = Flags::Parse(&argc, argv, flag_list);
139 if (!parse_ok) {
140 absl::PrintF("%s", usage);
141 return (2);
142 }
143 port::InitMain(argv[0], &argc, &argv);
144
145 if (!FLAGS_profile_path.empty() &&
146 (!FLAGS_graph_path.empty() || !FLAGS_run_meta_path.empty())) {
147 absl::FPrintF(stderr,
148 "--profile_path is set, do not set --graph_path or "
149 "--run_meta_path\n");
150 return 1;
151 }
152
153 std::vector<string> account_type_regexes =
154 absl::StrSplit(FLAGS_account_type_regexes, ',', absl::SkipEmpty());
155 std::vector<string> start_name_regexes =
156 absl::StrSplit(FLAGS_start_name_regexes, ',', absl::SkipEmpty());
157 std::vector<string> trim_name_regexes =
158 absl::StrSplit(FLAGS_trim_name_regexes, ',', absl::SkipEmpty());
159 std::vector<string> show_name_regexes =
160 absl::StrSplit(FLAGS_show_name_regexes, ',', absl::SkipEmpty());
161 std::vector<string> hide_name_regexes =
162 absl::StrSplit(FLAGS_hide_name_regexes, ',', absl::SkipEmpty());
163 std::vector<string> select =
164 absl::StrSplit(FLAGS_select, ',', absl::SkipEmpty());
165
166 string output_type;
167 std::map<string, string> output_options;
168 Status s = ParseOutput(FLAGS_output, &output_type, &output_options);
169 CHECK(s.ok()) << s.ToString();
170
171 string cmd = "";
172 if (argc == 1 && FLAGS_graph_path.empty() && FLAGS_profile_path.empty() &&
173 FLAGS_run_meta_path.empty()) {
174 PrintHelp();
175 return 0;
176 } else if (argc > 1) {
177 if (string(argv[1]) == kCmds[6]) {
178 PrintHelp();
179 return 0;
180 }
181 if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
182 string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
183 string(argv[1]) == kCmds[4]) {
184 cmd = argv[1];
185 }
186 }
187
188 absl::PrintF("Reading Files...\n");
189 std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
190 TF_Status* status = TF_NewStatus();
191 if (!FLAGS_checkpoint_path.empty()) {
192 ckpt_reader.reset(
193 new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status));
194 if (TF_GetCode(status) != TF_OK) {
195 absl::FPrintF(stderr, "%s\n", TF_Message(status));
196 TF_DeleteStatus(status);
197 return 1;
198 }
199 TF_DeleteStatus(status);
200 }
201
202 std::unique_ptr<TFStats> tf_stat;
203 if (!FLAGS_profile_path.empty()) {
204 tf_stat.reset(new TFStats(FLAGS_profile_path, std::move(ckpt_reader)));
205 } else {
206 absl::PrintF(
207 "Try to use a single --profile_path instead of "
208 "graph_path,op_log_path,run_meta_path\n");
209 std::unique_ptr<GraphDef> graph(new GraphDef());
210 if (!FLAGS_graph_path.empty()) {
211 s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false);
212 if (!s.ok()) {
213 absl::FPrintF(stderr, "Failed to read graph_path: %s\n", s.ToString());
214 return 1;
215 }
216 }
217
218 std::unique_ptr<OpLogProto> op_log(new OpLogProto());
219 if (!FLAGS_op_log_path.empty()) {
220 string op_log_str;
221 s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str);
222 if (!s.ok()) {
223 absl::FPrintF(stderr, "Failed to read op_log_path: %s\n", s.ToString());
224 return 1;
225 }
226 if (!ParseProtoUnlimited(op_log.get(), op_log_str)) {
227 absl::FPrintF(stderr, "Failed to parse op_log_path\n");
228 return 1;
229 }
230 }
231 tf_stat.reset(new TFStats(std::move(graph), nullptr, std::move(op_log),
232 std::move(ckpt_reader)));
233
234 std::vector<string> run_meta_files =
235 absl::StrSplit(FLAGS_run_meta_path, ',', absl::SkipEmpty());
236 for (int i = 0; i < run_meta_files.size(); ++i) {
237 std::unique_ptr<RunMetadata> run_meta(new RunMetadata());
238 s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(),
239 true);
240 if (!s.ok()) {
241 absl::FPrintF(stderr, "Failed to read run_meta_path %s. Status: %s\n",
242 run_meta_files[i], s.ToString());
243 return 1;
244 }
245 tf_stat->AddRunMeta(i, std::move(run_meta));
246 absl::FPrintF(stdout, "run graph coverage: %.2f\n",
247 tf_stat->run_coverage());
248 }
249 }
250
251 if (cmd == kCmds[4]) {
252 tf_stat->BuildAllViews();
253 Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
254 return 0;
255 }
256
257 Options opts(
258 FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_peak_bytes,
259 FLAGS_min_residual_bytes, FLAGS_min_output_bytes, FLAGS_min_micros,
260 FLAGS_min_accelerator_micros, FLAGS_min_cpu_micros, FLAGS_min_params,
261 FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
262 account_type_regexes, start_name_regexes, trim_name_regexes,
263 show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
264 select, output_type, output_options);
265
266 if (cmd == kCmds[2] || cmd == kCmds[3]) {
267 tf_stat->BuildView(cmd);
268 tf_stat->ShowMultiGraphNode(cmd, opts);
269 return 0;
270 } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
271 tf_stat->BuildView(cmd);
272 tf_stat->ShowGraphNode(cmd, opts);
273 return 0;
274 }
275
276 linenoiseSetCompletionCallback(completion);
277 linenoiseHistoryLoad(".tfprof_history.txt");
278
279 bool looped = false;
280 while (true) {
281 char* line = linenoise("tfprof> ");
282 if (line == nullptr) {
283 if (!looped) {
284 absl::FPrintF(stderr,
285 "Cannot start interative shell, "
286 "use 'bazel-bin' instead of 'bazel run'.\n");
287 }
288 break;
289 }
290 looped = true;
291 string line_s = line;
292 free(line);
293
294 if (line_s.empty()) {
295 absl::PrintF("%s", opts.ToString());
296 continue;
297 }
298 linenoiseHistoryAdd(line_s.c_str());
299 linenoiseHistorySave(".tfprof_history.txt");
300
301 Options new_opts = opts;
302 Status s = ParseCmdLine(line_s, &cmd, &new_opts);
303 if (!s.ok()) {
304 absl::FPrintF(stderr, "E: %s\n", s.ToString());
305 continue;
306 }
307 if (cmd == kCmds[5]) {
308 opts = new_opts;
309 } else if (cmd == kCmds[6]) {
310 PrintHelp();
311 } else if (cmd == kCmds[2] || cmd == kCmds[3]) {
312 tf_stat->BuildView(cmd);
313 tf_stat->ShowMultiGraphNode(cmd, new_opts);
314 } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
315 tf_stat->BuildView(cmd);
316 tf_stat->ShowGraphNode(cmd, new_opts);
317 } else if (cmd == kCmds[4]) {
318 tf_stat->BuildAllViews();
319 Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
320 }
321 }
322 return 0;
323 }
324 } // namespace tfprof
325 } // namespace tensorflow
326
main(int argc,char ** argv)327 int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); }
328